From 54b5d0b4f832ea18e0fa34661d3eb3f8f1b7424b Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 7 Nov 2024 13:34:11 -0800 Subject: [PATCH] Added check for the illegal use of an `await` keyword in a lambda. This addresses #9406. --- packages/pyright-internal/src/analyzer/binder.ts | 6 +++--- .../pyright-internal/src/tests/samples/coroutines1.py | 9 +++++++++ .../pyright-internal/src/tests/typeEvaluator3.test.ts | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 49eb10d7c899..97bbe26fb54e 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -1600,9 +1600,9 @@ export class Binder extends ParseTreeWalker { override visitAwait(node: AwaitNode) { // Make sure this is within an async lambda or function. - const enclosingFunction = ParseTreeUtils.getEnclosingFunction(node); - if (enclosingFunction === undefined || !enclosingFunction.d.isAsync) { - if (this._fileInfo.ipythonMode && enclosingFunction === undefined) { + const execScopeNode = ParseTreeUtils.getExecutionScopeNode(node); + if (execScopeNode?.nodeType !== ParseNodeType.Function || !execScopeNode.d.isAsync) { + if (this._fileInfo.ipythonMode && execScopeNode?.nodeType === ParseNodeType.Module) { // Top level await is allowed in ipython mode. return true; } diff --git a/packages/pyright-internal/src/tests/samples/coroutines1.py b/packages/pyright-internal/src/tests/samples/coroutines1.py index 382a871b6ac3..8e0af8bcfe35 100644 --- a/packages/pyright-internal/src/tests/samples/coroutines1.py +++ b/packages/pyright-internal/src/tests/samples/coroutines1.py @@ -16,6 +16,15 @@ async def coroutine1(): await a +async def func1() -> int: ... + + +async def func2() -> None: + # This should generate an error because await cannot be + # used in a lambda. + x = lambda: await func2() + + def needs_int(val: int): pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index c6f2890d2d58..001d9cca2722 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -159,7 +159,7 @@ test('Coroutines1', () => { configOptions.defaultPythonVersion = pythonVersion3_10; const analysisResults = TestUtils.typeAnalyzeSampleFiles(['coroutines1.py'], configOptions); - TestUtils.validateResults(analysisResults, 4); + TestUtils.validateResults(analysisResults, 5); }); test('Coroutines2', () => {