From a2c380ecb4c005f3cea4db2fb5e552eccff8fae3 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Mon, 2 Dec 2024 05:08:46 -0800 Subject: [PATCH] Improved dataclass converters to support generic types. This addresses #9512. (#9528) --- .../src/analyzer/dataClasses.ts | 35 ++++++++++++++++--- .../src/tests/samples/dataclassConverter3.py | 20 +++++++++++ .../src/tests/typeEvaluator4.test.ts | 6 ++++ 3 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/dataclassConverter3.py diff --git a/packages/pyright-internal/src/analyzer/dataClasses.ts b/packages/pyright-internal/src/analyzer/dataClasses.ts index b56aa9773732..b7c705fffd8f 100644 --- a/packages/pyright-internal/src/analyzer/dataClasses.ts +++ b/packages/pyright-internal/src/analyzer/dataClasses.ts @@ -63,12 +63,15 @@ import { OverloadedType, TupleTypeArg, Type, + TypeVarScopeType, TypeVarType, UnknownType, + Variance, } from './types'; import { addSolutionForSelfType, applySolvedTypeVars, + buildSolution, buildSolutionFromSpecializedClass, computeMroLinearization, convertToInstance, @@ -576,9 +579,9 @@ export function synthesizeDataClassMethods( entry.name, getDescriptorForConverterField( evaluator, + classType, node, entry.nameNode, - entry.typeAnnotationNode, entry.converter, entry.name, fieldType, @@ -876,7 +879,7 @@ function getConverterInputType( ): Type { const converterType = getConverterAsFunction( evaluator, - evaluator.getTypeOfExpression(converterNode.d.valueExpr).type + evaluator.getTypeOfExpression(converterNode.d.valueExpr, EvalFlags.NoSpecialize).type ); if (!converterType) { @@ -1005,9 +1008,9 @@ function getConverterAsFunction( // type. function getDescriptorForConverterField( evaluator: TypeEvaluator, + dataclass: ClassType, dataclassNode: ParseNode, fieldNameNode: NameNode | undefined, - fieldAnnotationNode: TypeAnnotationNode | undefined, converterNode: ParseNode, fieldName: string, getType: Type, @@ -1027,6 +1030,26 @@ function getDescriptorForConverterField( /* declaredMetaclass */ undefined, isInstantiableClass(typeMetaclass) ? typeMetaclass : UnknownType.create() ); + + const scopeId = getScopeIdForNode(converterNode); + descriptorClass.shared.typeVarScopeId = scopeId; + + // Make the descriptor generic, copying the type parameters from the dataclass. + descriptorClass.shared.typeParams = dataclass.shared.typeParams.map((typeParm) => { + const typeParam = TypeVarType.cloneForScopeId( + typeParm, + scopeId, + descriptorClass.shared.name, + TypeVarScopeType.Class + ); + typeParam.priv.computedVariance = Variance.Covariant; + return typeParam; + }); + + const solution = buildSolution(dataclass.shared.typeParams, descriptorClass.shared.typeParams); + getType = applySolvedTypeVars(getType, solution); + setType = applySolvedTypeVars(setType, solution); + descriptorClass.shared.baseClasses.push(evaluator.getBuiltInType(dataclassNode, 'object')); computeMroLinearization(descriptorClass); @@ -1067,7 +1090,11 @@ function getDescriptorForConverterField( const getSymbol = Symbol.createWithType(SymbolFlags.ClassMember, getFunction); fields.set('__get__', getSymbol); - return Symbol.createWithType(SymbolFlags.ClassMember, ClassType.cloneAsInstance(descriptorClass), fieldNameNode); + const descriptorInstance = ClassType.specialize(ClassType.cloneAsInstance(descriptorClass), [ + ...dataclass.shared.typeParams, + ]); + + return Symbol.createWithType(SymbolFlags.ClassMember, descriptorInstance, fieldNameNode); } // If the specified type is a descriptor — in particular, if it implements a diff --git a/packages/pyright-internal/src/tests/samples/dataclassConverter3.py b/packages/pyright-internal/src/tests/samples/dataclassConverter3.py new file mode 100644 index 000000000000..5b6a59e1a602 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/dataclassConverter3.py @@ -0,0 +1,20 @@ +# This sample tests the case where a dataclass converter is used with +# a generic type. + +from typing import Any, Callable, dataclass_transform + + +def model_field(*, converter: Callable[..., Any]) -> Any: ... + + +@dataclass_transform(field_specifiers=(model_field,)) +class ModelBase: ... + + +class DC1[T](ModelBase): + data: set[T] = model_field(converter=set) + + +x = DC1([1, 2]) +reveal_type(x, expected_text="DC1[int]") +reveal_type(x.data, expected_text="set[int]") diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index d2d0c897dbfc..61a69c30485d 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -441,6 +441,12 @@ test('DataClassConverter2', () => { TestUtils.validateResults(analysisResults, 4); }); +test('DataClassConverter3', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassConverter3.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('DataClassPostInit1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassPostInit1.py']);