Skip to content

Commit

Permalink
Fixed bug that results in an incorrect type evaluation when a match
Browse files Browse the repository at this point in the history
… statement uses a pattern with a target expression that overwrites the subject expression. This addresses #9418. (#9428)
  • Loading branch information
erictraut authored Nov 8, 2024
1 parent 9d60c43 commit c616423
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
26 changes: 25 additions & 1 deletion packages/pyright-internal/src/analyzer/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ export class Binder extends ParseTreeWalker {
// and require code flow analysis to resolve.
private _currentScopeCodeFlowExpressions: Set<string> | undefined;

// If we're actively binding a match statement, this is the current
// match expression.
private _currentMatchSubjExpr: ExpressionNode | undefined;

// Aliases of "typing" and "typing_extensions".
private _typingImportAliases: string[] = [];

Expand Down Expand Up @@ -2275,10 +2279,20 @@ export class Binder extends ParseTreeWalker {

this._currentFlowNode = this._finishFlowLabel(preGuardLabel);

// Note the active match subject expression prior to binding
// the pattern. If the pattern involves any targets that overwrite
// the subject expression, this will be set to undefined.
this._currentMatchSubjExpr = node.d.expr;

// Bind the pattern.
this.walk(caseStatement.d.pattern);

this._createFlowNarrowForPattern(node.d.expr, caseStatement);
// If the pattern involves targets that overwrite the subject
// expression, skip creating a flow node for narrowing the subject.
if (this._currentMatchSubjExpr) {
this._createFlowNarrowForPattern(node.d.expr, caseStatement);
this._currentMatchSubjExpr = undefined;
}

// Apply the guard expression.
if (caseStatement.d.guardExpr) {
Expand Down Expand Up @@ -2465,6 +2479,16 @@ export class Binder extends ParseTreeWalker {
const symbol = this._bindNameToScope(this._currentScope, target);
this._createAssignmentTargetFlowNodes(target, /* walkTargets */ false, /* unbound */ false);

// See if the target overwrites all or a portion of the subject expression.
if (this._currentMatchSubjExpr) {
if (
ParseTreeUtils.isMatchingExpression(target, this._currentMatchSubjExpr) ||
ParseTreeUtils.isPartialMatchingExpression(target, this._currentMatchSubjExpr)
) {
this._currentMatchSubjExpr = undefined;
}
}

if (symbol) {
const declaration: VariableDeclaration = {
type: DeclarationType.Variable,
Expand Down
7 changes: 6 additions & 1 deletion packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -859,20 +859,25 @@ function narrowTypeBasedOnClassPattern(
LocAddendum.typeNotClass().format({ type: evaluator.printType(exprType) }),
pattern.d.className
);
return NeverType.createNever();

return isPositiveTest ? UnknownType.create() : type;
} else if (isInstantiableClass(exprType)) {
if (ClassType.isProtocolClass(exprType) && !ClassType.isRuntimeCheckable(exprType)) {
evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocAddendum.protocolRequiresRuntimeCheckable(),
pattern.d.className
);

return isPositiveTest ? UnknownType.create() : type;
} else if (ClassType.isTypedDictClass(exprType)) {
evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocMessage.typedDictInClassPattern(),
pattern.d.className
);

return isPositiveTest ? UnknownType.create() : type;
}
}

Expand Down
34 changes: 34 additions & 0 deletions packages/pyright-internal/src/tests/samples/matchClass7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This sample tests the case where a class pattern overwrites the subject
# expression.

from dataclasses import dataclass


@dataclass
class DC1:
val: str


def func1(val: DC1):
result = val

match result:
case DC1(result):
reveal_type(result, expected_text="str")


@dataclass
class DC2:
val: DC1


def func2(val: DC2):
result = val

match result.val:
case DC1(result):
reveal_type(result, expected_text="str")

# This should generate an error because result.val
# is no longer valid at this point.
print(result.val)
8 changes: 8 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator6.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,14 @@ test('MatchClass6', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('MatchClass7', () => {
const configOptions = new ConfigOptions(Uri.empty());

configOptions.defaultPythonVersion = pythonVersion3_10;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['matchClass7.py'], configOptions);
TestUtils.validateResults(analysisResults, 1);
});

test('MatchValue1', () => {
const configOptions = new ConfigOptions(Uri.empty());

Expand Down

0 comments on commit c616423

Please sign in to comment.