Skip to content

Commit

Permalink
Fixed issue that results in non-deterministic false positive error, o…
Browse files Browse the repository at this point in the history
…ften relating to the "awaitable" check. This addresses #9204. (#9250)
  • Loading branch information
erictraut authored Oct 18, 2024
1 parent af73281 commit 83a9f5f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 25 deletions.
16 changes: 3 additions & 13 deletions packages/pyright-internal/src/analyzer/constructorTransform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { ArgCategory, ExpressionNode, ParamCategory } from '../parser/parseNodes
import { ConstraintTracker } from './constraintTracker';
import { createFunctionFromConstructor } from './constructors';
import { getParamListDetails, ParamKind } from './parameterUtils';
import { Symbol, SymbolFlags } from './symbol';
import { Arg, FunctionResult, TypeEvaluator } from './typeEvaluatorTypes';
import {
AnyType,
Expand Down Expand Up @@ -124,11 +123,7 @@ function applyPartialTransform(
}

// Create a new copy of the functools.partial class that overrides the __call__ method.
const newPartialClass = ClassType.cloneForSymbolTableUpdate(result.returnType);
ClassType.getSymbolTable(newPartialClass).set(
'__call__',
Symbol.createWithType(SymbolFlags.ClassMember, transformResult.returnType)
);
const newPartialClass = ClassType.cloneForPartial(result.returnType, transformResult.returnType);

return {
returnType: newPartialClass,
Expand Down Expand Up @@ -176,9 +171,6 @@ function applyPartialTransform(
return undefined;
}

// Create a new copy of the functools.partial class that overrides the __call__ method.
const newPartialClass = ClassType.cloneForSymbolTableUpdate(result.returnType);

let synthesizedCallType: Type;
if (applicableOverloads.length === 1) {
synthesizedCallType = applicableOverloads[0];
Expand All @@ -191,10 +183,8 @@ function applyPartialTransform(
);
}

ClassType.getSymbolTable(newPartialClass).set(
'__call__',
Symbol.createWithType(SymbolFlags.ClassMember, synthesizedCallType)
);
// Create a new copy of the functools.partial class that overrides the __call__ method.
const newPartialClass = ClassType.cloneForPartial(result.returnType, synthesizedCallType);

return {
returnType: newPartialClass,
Expand Down
12 changes: 11 additions & 1 deletion packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,8 @@ export function* getClassMemberIterator(

// Next look at class members.
if ((flags & MemberAccessFlags.SkipClassMembers) === 0) {
const symbol = memberFields.get(memberName);
let symbol = memberFields.get(memberName);

if (symbol && symbol.isClassMember()) {
const hasDeclaredType = symbol.hasTypedDeclarations();
if (!declaredTypesOnly || hasDeclaredType) {
Expand All @@ -1828,6 +1829,15 @@ export function* getClassMemberIterator(
}
}

// Handle the special case of a __call__ class member in a partial class.
if (
memberName === '__call__' &&
classType.priv.partialCallType &&
ClassType.isSameGenericClass(classType, specializedMroClass)
) {
symbol = Symbol.createWithType(SymbolFlags.ClassMember, classType.priv.partialCallType);
}

const cm: ClassMember = {
symbol,
isInstanceMember,
Expand Down
16 changes: 5 additions & 11 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,9 @@ export interface ClassDetailsPriv {
// the "deprecated" class. This allows these instances to be used
// as decorators for other classes or functions.
deprecatedInstanceMessage?: string | undefined;

// Special-case fields for partial class.
partialCallType?: Type | undefined;
}

export interface ClassType extends TypeBase<TypeCategory.Class> {
Expand Down Expand Up @@ -1002,12 +1005,9 @@ export namespace ClassType {
return newClassType;
}

export function cloneForSymbolTableUpdate(classType: ClassType): ClassType {
export function cloneForPartial(classType: ClassType, partialCallType: Type): ClassType {
const newClassType = TypeBase.cloneType(classType);
newClassType.shared = { ...newClassType.shared };
newClassType.shared.fields = new Map(newClassType.shared.fields);
newClassType.shared.mro = Array.from(newClassType.shared.mro);
newClassType.shared.mro[0] = cloneAsInstantiable(newClassType);
newClassType.priv.partialCallType = partialCallType;
return newClassType;
}

Expand Down Expand Up @@ -3362,12 +3362,6 @@ export function isTypeSame(type1: Type, type2: Type, options: TypeSameOptions =
return false;
}

// This test is required for the "partial" class, which clones
// the symbol table to add a custom __call__ method.
if (type1.shared.fields !== classType2.shared.fields) {
return false;
}

if (!type1.priv.isUnpacked !== !classType2.priv.isUnpacked) {
return false;
}
Expand Down

0 comments on commit 83a9f5f

Please sign in to comment.