From c616423676c671609d90707cab34a0a4b0a5c6cc Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 7 Nov 2024 22:21:42 -0800 Subject: [PATCH] Fixed bug that results in an incorrect type evaluation when a `match` statement uses a pattern with a target expression that overwrites the subject expression. This addresses #9418. (#9428) --- .../pyright-internal/src/analyzer/binder.ts | 26 +++++++++++++- .../src/analyzer/patternMatching.ts | 7 +++- .../src/tests/samples/matchClass7.py | 34 +++++++++++++++++++ .../src/tests/typeEvaluator6.test.ts | 8 +++++ 4 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/matchClass7.py diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 4825323a6f3e..6c2a81c029a5 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -201,6 +201,10 @@ export class Binder extends ParseTreeWalker { // and require code flow analysis to resolve. private _currentScopeCodeFlowExpressions: Set | undefined; + // If we're actively binding a match statement, this is the current + // match expression. + private _currentMatchSubjExpr: ExpressionNode | undefined; + // Aliases of "typing" and "typing_extensions". private _typingImportAliases: string[] = []; @@ -2275,10 +2279,20 @@ export class Binder extends ParseTreeWalker { this._currentFlowNode = this._finishFlowLabel(preGuardLabel); + // Note the active match subject expression prior to binding + // the pattern. If the pattern involves any targets that overwrite + // the subject expression, this will be set to undefined. + this._currentMatchSubjExpr = node.d.expr; + // Bind the pattern. this.walk(caseStatement.d.pattern); - this._createFlowNarrowForPattern(node.d.expr, caseStatement); + // If the pattern involves targets that overwrite the subject + // expression, skip creating a flow node for narrowing the subject. + if (this._currentMatchSubjExpr) { + this._createFlowNarrowForPattern(node.d.expr, caseStatement); + this._currentMatchSubjExpr = undefined; + } // Apply the guard expression. if (caseStatement.d.guardExpr) { @@ -2465,6 +2479,16 @@ export class Binder extends ParseTreeWalker { const symbol = this._bindNameToScope(this._currentScope, target); this._createAssignmentTargetFlowNodes(target, /* walkTargets */ false, /* unbound */ false); + // See if the target overwrites all or a portion of the subject expression. + if (this._currentMatchSubjExpr) { + if ( + ParseTreeUtils.isMatchingExpression(target, this._currentMatchSubjExpr) || + ParseTreeUtils.isPartialMatchingExpression(target, this._currentMatchSubjExpr) + ) { + this._currentMatchSubjExpr = undefined; + } + } + if (symbol) { const declaration: VariableDeclaration = { type: DeclarationType.Variable, diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 3dd371077daf..ceaa1eea6c06 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -859,7 +859,8 @@ function narrowTypeBasedOnClassPattern( LocAddendum.typeNotClass().format({ type: evaluator.printType(exprType) }), pattern.d.className ); - return NeverType.createNever(); + + return isPositiveTest ? UnknownType.create() : type; } else if (isInstantiableClass(exprType)) { if (ClassType.isProtocolClass(exprType) && !ClassType.isRuntimeCheckable(exprType)) { evaluator.addDiagnostic( @@ -867,12 +868,16 @@ function narrowTypeBasedOnClassPattern( LocAddendum.protocolRequiresRuntimeCheckable(), pattern.d.className ); + + return isPositiveTest ? UnknownType.create() : type; } else if (ClassType.isTypedDictClass(exprType)) { evaluator.addDiagnostic( DiagnosticRule.reportGeneralTypeIssues, LocMessage.typedDictInClassPattern(), pattern.d.className ); + + return isPositiveTest ? UnknownType.create() : type; } } diff --git a/packages/pyright-internal/src/tests/samples/matchClass7.py b/packages/pyright-internal/src/tests/samples/matchClass7.py new file mode 100644 index 000000000000..c2f4075650d9 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/matchClass7.py @@ -0,0 +1,34 @@ +# This sample tests the case where a class pattern overwrites the subject +# expression. + +from dataclasses import dataclass + + +@dataclass +class DC1: + val: str + + +def func1(val: DC1): + result = val + + match result: + case DC1(result): + reveal_type(result, expected_text="str") + + +@dataclass +class DC2: + val: DC1 + + +def func2(val: DC2): + result = val + + match result.val: + case DC1(result): + reveal_type(result, expected_text="str") + + # This should generate an error because result.val + # is no longer valid at this point. + print(result.val) diff --git a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts index 1e89799dd666..b5c8c98db9ca 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts @@ -520,6 +520,14 @@ test('MatchClass6', () => { TestUtils.validateResults(analysisResults, 0); }); +test('MatchClass7', () => { + const configOptions = new ConfigOptions(Uri.empty()); + + configOptions.defaultPythonVersion = pythonVersion3_10; + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['matchClass7.py'], configOptions); + TestUtils.validateResults(analysisResults, 1); +}); + test('MatchValue1', () => { const configOptions = new ConfigOptions(Uri.empty());