diff --git a/packages/pyright-internal/src/analyzer/parameterUtils.ts b/packages/pyright-internal/src/analyzer/parameterUtils.ts index 7571dcf39c8d..53caf7b475d5 100644 --- a/packages/pyright-internal/src/analyzer/parameterUtils.ts +++ b/packages/pyright-internal/src/analyzer/parameterUtils.ts @@ -39,6 +39,7 @@ export enum ParamKind { Positional, Standard, Keyword, + ExpandedArgs, } export interface VirtualParamDetails { @@ -185,7 +186,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails { index, tupleArg.type, /* defaultArgTypeOverride */ undefined, - ParamKind.Positional + ParamKind.ExpandedArgs ); if (category === ParamCategory.Simple) { diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 35e899752b26..d3b775bd13b2 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -25578,13 +25578,14 @@ export function createTypeEvaluator( const destParamName = destParam.param.name ?? ''; const srcParamName = srcParam.param.name ?? ''; if (destParamName) { - const isDestPositionalOnly = destParam.kind === ParamKind.Positional; + const isDestPositionalOnly = + destParam.kind === ParamKind.Positional || destParam.kind === ParamKind.ExpandedArgs; if ( !isDestPositionalOnly && destParam.param.category !== ParamCategory.ArgsList && srcParam.param.category !== ParamCategory.ArgsList ) { - if (srcParam.kind === ParamKind.Positional) { + if (srcParam.kind === ParamKind.Positional || srcParam.kind === ParamKind.ExpandedArgs) { diag?.createAddendum().addMessage( LocAddendum.functionParamPositionOnly().format({ name: destParamName, diff --git a/packages/pyright-internal/src/tests/samples/typeVarTuple14.py b/packages/pyright-internal/src/tests/samples/typeVarTuple14.py index e092e2d203d7..5e26314b7275 100644 --- a/packages/pyright-internal/src/tests/samples/typeVarTuple14.py +++ b/packages/pyright-internal/src/tests/samples/typeVarTuple14.py @@ -15,12 +15,10 @@ def call_with_params(func: Callable[[*Ts], R], *params: *Ts) -> R: return func(*params) -def callback1(*args: int) -> int: - ... +def callback1(*args: int) -> int: ... -def callback2(*args: *tuple[int, int]) -> int: - ... +def callback2(*args: *tuple[int, int]) -> int: ... call_with_params(callback1) @@ -38,8 +36,7 @@ def callback2(*args: *tuple[int, int]) -> int: call_with_params(callback2, 1, "") -def callback3(*args: *tuple[int, *tuple[str, ...], int]) -> int: - ... +def callback3(*args: *tuple[int, *tuple[str, ...], int]) -> int: ... # This should generate an error. @@ -55,23 +52,58 @@ def callback3(*args: *tuple[int, *tuple[str, ...], int]) -> int: call_with_params(callback3, 1, 1, 2) -class Foo: +class ClassA: @classmethod - def foo(cls, *shape: *Ts) -> tuple[*Ts]: - ... + def method1(cls, *shape: *Ts) -> tuple[*Ts]: ... -def call_with_params2(target: Callable[[*Ts], int]) -> tuple[*Ts]: - ... +def func1(target: Callable[[*Ts], int]) -> tuple[*Ts]: ... -def callback4(a: int, b: str, /) -> int: - ... +def func2(a: int, b: str, /) -> int: ... -def g(action: Callable[[int, str], int]): - v1 = call_with_params2(callback4) +def func3(action: Callable[[int, str], int]): + v1 = func1(func2) reveal_type(v1, expected_text="tuple[int, str]") - v2 = call_with_params2(action) + v2 = func1(action) reveal_type(v2, expected_text="tuple[int, str]") + + +def func4(*args: *tuple[int, str]): ... + + +func4(1, "") + +# This should generate an error. +func4() + +# This should generate an error. +func4(1) + +# This should generate an error. +func4(1, "", "") + + +def func5(*args: *tuple[int, *tuple[str, ...], int]): ... + + +func5(1, 1) +func5(1, "", 1) +func5(1, "", "", 1) + +# This should generate an error. +func5() + +# This should generate an error. +func5(1) + +# This should generate an error. +func5("") + +# This should generate an error. +func5(1, "") + +# This should generate an error. +func5(1, "", "") diff --git a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts index 429f978b5630..6fb3c7b4f888 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts @@ -296,7 +296,7 @@ test('TypeVarTuple14', () => { configOptions.defaultPythonVersion = pythonVersion3_11; const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeVarTuple14.py'], configOptions); - TestUtils.validateResults(analysisResults, 6); + TestUtils.validateResults(analysisResults, 14); }); test('TypeVarTuple15', () => {