Skip to content

Commit

Permalink
Improved dataclass converters to support generic types. This addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
erictraut authored Dec 2, 2024
1 parent aa3b270 commit a2c380e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
35 changes: 31 additions & 4 deletions packages/pyright-internal/src/analyzer/dataClasses.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ import {
OverloadedType,
TupleTypeArg,
Type,
TypeVarScopeType,
TypeVarType,
UnknownType,
Variance,
} from './types';
import {
addSolutionForSelfType,
applySolvedTypeVars,
buildSolution,
buildSolutionFromSpecializedClass,
computeMroLinearization,
convertToInstance,
Expand Down Expand Up @@ -576,9 +579,9 @@ export function synthesizeDataClassMethods(
entry.name,
getDescriptorForConverterField(
evaluator,
classType,
node,
entry.nameNode,
entry.typeAnnotationNode,
entry.converter,
entry.name,
fieldType,
Expand Down Expand Up @@ -876,7 +879,7 @@ function getConverterInputType(
): Type {
const converterType = getConverterAsFunction(
evaluator,
evaluator.getTypeOfExpression(converterNode.d.valueExpr).type
evaluator.getTypeOfExpression(converterNode.d.valueExpr, EvalFlags.NoSpecialize).type
);

if (!converterType) {
Expand Down Expand Up @@ -1005,9 +1008,9 @@ function getConverterAsFunction(
// type.
function getDescriptorForConverterField(
evaluator: TypeEvaluator,
dataclass: ClassType,
dataclassNode: ParseNode,
fieldNameNode: NameNode | undefined,
fieldAnnotationNode: TypeAnnotationNode | undefined,
converterNode: ParseNode,
fieldName: string,
getType: Type,
Expand All @@ -1027,6 +1030,26 @@ function getDescriptorForConverterField(
/* declaredMetaclass */ undefined,
isInstantiableClass(typeMetaclass) ? typeMetaclass : UnknownType.create()
);

const scopeId = getScopeIdForNode(converterNode);
descriptorClass.shared.typeVarScopeId = scopeId;

// Make the descriptor generic, copying the type parameters from the dataclass.
descriptorClass.shared.typeParams = dataclass.shared.typeParams.map((typeParm) => {
const typeParam = TypeVarType.cloneForScopeId(
typeParm,
scopeId,
descriptorClass.shared.name,
TypeVarScopeType.Class
);
typeParam.priv.computedVariance = Variance.Covariant;
return typeParam;
});

const solution = buildSolution(dataclass.shared.typeParams, descriptorClass.shared.typeParams);
getType = applySolvedTypeVars(getType, solution);
setType = applySolvedTypeVars(setType, solution);

descriptorClass.shared.baseClasses.push(evaluator.getBuiltInType(dataclassNode, 'object'));
computeMroLinearization(descriptorClass);

Expand Down Expand Up @@ -1067,7 +1090,11 @@ function getDescriptorForConverterField(
const getSymbol = Symbol.createWithType(SymbolFlags.ClassMember, getFunction);
fields.set('__get__', getSymbol);

return Symbol.createWithType(SymbolFlags.ClassMember, ClassType.cloneAsInstance(descriptorClass), fieldNameNode);
const descriptorInstance = ClassType.specialize(ClassType.cloneAsInstance(descriptorClass), [
...dataclass.shared.typeParams,
]);

return Symbol.createWithType(SymbolFlags.ClassMember, descriptorInstance, fieldNameNode);
}

// If the specified type is a descriptor — in particular, if it implements a
Expand Down
20 changes: 20 additions & 0 deletions packages/pyright-internal/src/tests/samples/dataclassConverter3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This sample tests the case where a dataclass converter is used with
# a generic type.

from typing import Any, Callable, dataclass_transform


def model_field(*, converter: Callable[..., Any]) -> Any: ...


@dataclass_transform(field_specifiers=(model_field,))
class ModelBase: ...


class DC1[T](ModelBase):
data: set[T] = model_field(converter=set)


x = DC1([1, 2])
reveal_type(x, expected_text="DC1[int]")
reveal_type(x.data, expected_text="set[int]")
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator4.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ test('DataClassConverter2', () => {
TestUtils.validateResults(analysisResults, 4);
});

test('DataClassConverter3', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassConverter3.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('DataClassPostInit1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassPostInit1.py']);

Expand Down

0 comments on commit a2c380e

Please sign in to comment.