Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize checking involving large discriminated union types #42556

Merged
merged 18 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 124 additions & 20 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17533,8 +17533,17 @@ namespace ts {

function typeRelatedToSomeType(source: Type, target: UnionOrIntersectionType, reportErrors: boolean): Ternary {
const targetTypes = target.types;
if (target.flags & TypeFlags.Union && containsType(targetTypes, source)) {
return Ternary.True;
if (target.flags & TypeFlags.Union) {
if (containsType(targetTypes, source)) {
return Ternary.True;
}
const match = getMatchingUnionConstituentForType(<UnionType>target, source);
if (match) {
const related = isRelatedTo(source, match, /*reportErrors*/ false);
if (related) {
return related;
}
}
}
for (const type of targetTypes) {
const related = isRelatedTo(source, type, /*reportErrors*/ false);
Expand Down Expand Up @@ -21371,6 +21380,82 @@ namespace ts {
return result;
}

// Given a set of constituent types and a property name, create and return a map keyed by the literal
// types of the property by that name in each constituent type. No map is returned if some key property
// has a non-literal type or if less than 10 or less than 50% of the constituents have a unique key.
// Entries with duplicate keys have unknownType as the value.
function mapTypesByKeyProperty(types: Type[], name: __String) {
const map = new Map<TypeId, Type>();
let count = 0;
for (const type of types) {
if (type.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive)) {
const discriminant = getTypeOfPropertyOfType(type, name);
if (discriminant) {
if (!isLiteralType(discriminant)) {
return undefined;
}
let duplicate = false;
forEachType(discriminant, t => {
const id = getTypeId(getRegularTypeOfLiteralType(t));
const existing = map.get(id);
if (!existing) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let count = 0;
// The count of types with unique keys.
let count = 0;

map.set(id, type);
}
else if (existing !== unknownType) {
map.set(id, unknownType);
duplicate = true;
}
});
if (!duplicate) count++;
}
}
}
return count >= 10 && count * 2 >= types.length ? map : undefined;
}

// Return the name of a discriminant property for which it was possible and feasible to construct a map of
// constituent types keyed by the literal types of the property by that name in each constituent type.
function getKeyPropertyName(unionType: UnionType): __String | undefined {
const types = unionType.types;
// We only construct maps for large unions with non-primitive constituents.
if (types.length < 10 || getObjectFlags(unionType) & ObjectFlags.PrimitiveUnion) {
return undefined;
}
if (unionType.keyPropertyName === undefined) {
// The candidate key property name is the name of the first property with a unit type in one of the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To bring this in line with the comments above:

Suggested change
return count >= 10 && count * 2 >= types.length ? map : undefined;
return count < 10 || types.length < count * 2 ? undefined : map;

// constituent types.
const keyPropertyName = forEach(types, t =>
t.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive) ?
forEach(getPropertiesOfType(t), p => isUnitType(getTypeOfSymbol(p)) ? p.escapedName : undefined) :
undefined);
const mapByKeyProperty = keyPropertyName && mapTypesByKeyProperty(types, keyPropertyName);
unionType.keyPropertyName = mapByKeyProperty ? keyPropertyName : "" as __String;
unionType.constituentMap = mapByKeyProperty;
}
return (unionType.keyPropertyName as string).length ? unionType.keyPropertyName : undefined;
}

// Given a union type for which getKeyPropertyName returned a non-undefined result, return the constituent
// that corresponds to the given key type for that property name.
function getConstituentTypeForKeyType(unionType: UnionType, keyType: Type) {
const result = unionType.constituentMap?.get(getTypeId(getRegularTypeOfLiteralType(keyType)));
return result !== unknownType ? result : undefined;
}

function getMatchingUnionConstituentForType(unionType: UnionType, type: Type) {
const keyPropertyName = getKeyPropertyName(unionType);
const propType = keyPropertyName && getTypeOfPropertyOfType(type, keyPropertyName);
return propType && getConstituentTypeForKeyType(unionType, propType);
}

function getMatchingUnionConstituentForObjectLiteral(unionType: UnionType, node: ObjectLiteralExpression) {
const keyPropertyName = getKeyPropertyName(unionType);
const propNode = keyPropertyName && find(node.properties, p => p.symbol && p.kind === SyntaxKind.PropertyAssignment &&
p.symbol.escapedName === keyPropertyName && isPossiblyDiscriminantValue(p.initializer));
const propType = propNode && getTypeOfExpression((<PropertyAssignment>propNode).initializer);
return propType && getConstituentTypeForKeyType(unionType, propType);
}

function isOrContainsMatchingReference(source: Node, target: Node) {
return isMatchingReference(source, target) || containsMatchingReference(source, target);
}
Expand Down Expand Up @@ -22473,8 +22558,7 @@ namespace ts {
}
}
if (isMatchingReferenceDiscriminant(expr, type)) {
type = narrowTypeByDiscriminant(type, expr as AccessExpression,
t => narrowTypeBySwitchOnDiscriminant(t, flow.switchStatement, flow.clauseStart, flow.clauseEnd));
type = narrowTypeBySwitchOnDiscriminantProperty(type, expr as AccessExpression, flow.switchStatement, flow.clauseStart, flow.clauseEnd);
}
}
return createFlowType(type, isIncomplete(flowType));
Expand Down Expand Up @@ -22648,8 +22732,7 @@ namespace ts {
if (propName === undefined) {
return type;
}
const includesNullable = strictNullChecks && maybeTypeOfKind(type, TypeFlags.Nullable);
const removeNullable = includesNullable && isOptionalChain(access);
const removeNullable = strictNullChecks && isOptionalChain(access) && maybeTypeOfKind(type, TypeFlags.Nullable);
let propType = getTypeOfPropertyOfType(removeNullable ? getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull) : type, propName);
if (!propType) {
return type;
Expand All @@ -22662,6 +22745,28 @@ namespace ts {
});
}

function narrowTypeByDiscriminantProperty(type: Type, access: AccessExpression, operator: SyntaxKind, value: Expression, assumeTrue: boolean) {
if ((operator === SyntaxKind.EqualsEqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) && type.flags & TypeFlags.Union &&
getKeyPropertyName(<UnionType>type) === getAccessedPropertyName(access)) {
const candidate = getConstituentTypeForKeyType(<UnionType>type, getTypeOfExpression(value));
if (candidate) {
return operator === (assumeTrue ? SyntaxKind.EqualsEqualsEqualsToken : SyntaxKind.ExclamationEqualsEqualsToken) ? candidate : filterType(type, t => t !== candidate);
}
}
return narrowTypeByDiscriminant(type, access, t => narrowTypeByEquality(t, operator, value, assumeTrue));
}

function narrowTypeBySwitchOnDiscriminantProperty(type: Type, access: AccessExpression, switchStatement: SwitchStatement, clauseStart: number, clauseEnd: number) {
if (clauseStart < clauseEnd && type.flags & TypeFlags.Union && getKeyPropertyName(<UnionType>type) === getAccessedPropertyName(access)) {
const clauseTypes = getSwitchClauseTypes(switchStatement).slice(clauseStart, clauseEnd);
const candidate = getUnionType(map(clauseTypes, t => getConstituentTypeForKeyType(<UnionType>type, t) || unknownType));
if (candidate !== unknownType) {
return candidate;
}
}
return narrowTypeByDiscriminant(type, access, t => narrowTypeBySwitchOnDiscriminant(t, switchStatement, clauseStart, clauseEnd));
}

function narrowTypeByTruthiness(type: Type, expr: Expression, assumeTrue: boolean): Type {
if (isMatchingReference(reference, expr)) {
return getTypeWithFacts(type, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy);
Expand Down Expand Up @@ -22731,10 +22836,10 @@ namespace ts {
}
}
if (isMatchingReferenceDiscriminant(left, type)) {
return narrowTypeByDiscriminant(type, <AccessExpression>left, t => narrowTypeByEquality(t, operator, right, assumeTrue));
return narrowTypeByDiscriminantProperty(type, <AccessExpression>left, operator, right, assumeTrue);
}
if (isMatchingReferenceDiscriminant(right, type)) {
return narrowTypeByDiscriminant(type, <AccessExpression>right, t => narrowTypeByEquality(t, operator, left, assumeTrue));
return narrowTypeByDiscriminantProperty(type, <AccessExpression>right, operator, left, assumeTrue);
}
if (isMatchingConstructorReference(left)) {
return narrowTypeByConstructor(type, operator, right, assumeTrue);
Expand Down Expand Up @@ -22807,7 +22912,7 @@ namespace ts {
}
if (assumeTrue) {
const filterFn: (t: Type) => boolean = operator === SyntaxKind.EqualsEqualsToken ?
(t => areTypesComparable(t, valueType) || isCoercibleUnderDoubleEquals(t, valueType)) :
t => areTypesComparable(t, valueType) || isCoercibleUnderDoubleEquals(t, valueType) :
t => areTypesComparable(t, valueType);
return replacePrimitivesWithLiterals(filterType(type, filterFn), valueType);
}
Expand Down Expand Up @@ -24616,7 +24721,7 @@ namespace ts {
}

function discriminateContextualTypeByObjectMembers(node: ObjectLiteralExpression, contextualType: UnionType) {
return discriminateTypeByDiscriminableItems(contextualType,
return getMatchingUnionConstituentForObjectLiteral(contextualType, node) || discriminateTypeByDiscriminableItems(contextualType,
map(
filter(node.properties, p => !!p.symbol && p.kind === SyntaxKind.PropertyAssignment && isPossiblyDiscriminantValue(p.initializer) && isDiscriminantProperty(contextualType, p.symbol.escapedName)),
prop => ([() => checkExpression((prop as PropertyAssignment).initializer), prop.symbol.escapedName] as [() => Type, __String])
Expand Down Expand Up @@ -24646,15 +24751,9 @@ namespace ts {
const instantiatedType = instantiateContextualType(contextualType, node, contextFlags);
if (instantiatedType && !(contextFlags && contextFlags & ContextFlags.NoConstraints && instantiatedType.flags & TypeFlags.TypeVariable)) {
const apparentType = mapType(instantiatedType, getApparentType, /*noReductions*/ true);
if (apparentType.flags & TypeFlags.Union) {
if (isObjectLiteralExpression(node)) {
return discriminateContextualTypeByObjectMembers(node, apparentType as UnionType);
}
else if (isJsxAttributes(node)) {
return discriminateContextualTypeByJSXAttributes(node, apparentType as UnionType);
}
}
return apparentType;
return apparentType.flags & TypeFlags.Union && isObjectLiteralExpression(node) ? discriminateContextualTypeByObjectMembers(node, apparentType as UnionType) :
apparentType.flags & TypeFlags.Union && isJsxAttributes(node) ? discriminateContextualTypeByJSXAttributes(node, apparentType as UnionType) :
apparentType;
}
}

Expand Down Expand Up @@ -25086,8 +25185,9 @@ namespace ts {
if (forceTuple || inConstContext || contextualType && forEachType(contextualType, isTupleLikeType)) {
return createArrayLiteralType(createTupleType(elementTypes, elementFlags, /*readonly*/ inConstContext));
}
const reduction = !contextualType || checkMode && checkMode & CheckMode.Inferential ? UnionReduction.Subtype : UnionReduction.Literal;
return createArrayLiteralType(createArrayType(elementTypes.length ?
getUnionType(sameMap(elementTypes, (t, i) => elementFlags[i] & ElementFlags.Variadic ? getIndexedAccessTypeOrUndefined(t, numberType) || anyType : t), UnionReduction.Subtype) :
getUnionType(sameMap(elementTypes, (t, i) => elementFlags[i] & ElementFlags.Variadic ? getIndexedAccessTypeOrUndefined(t, numberType) || anyType : t), reduction) :
strictNullChecks ? implicitNeverType : undefinedWideningType, inConstContext));
}

Expand Down Expand Up @@ -41041,6 +41141,10 @@ namespace ts {
// Keep this up-to-date with the same logic within `getApparentTypeOfContextualType`, since they should behave similarly
function findMatchingDiscriminantType(source: Type, target: Type, isRelatedTo: (source: Type, target: Type) => Ternary, skipPartial?: boolean) {
if (target.flags & TypeFlags.Union && source.flags & (TypeFlags.Intersection | TypeFlags.Object)) {
const match = getMatchingUnionConstituentForType(<UnionType>target, source);
if (match) {
return match;
}
const sourceProperties = getPropertiesOfType(source);
if (sourceProperties) {
const sourcePropertiesFiltered = findDiscriminantProperties(sourceProperties, target);
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5303,6 +5303,10 @@ namespace ts {
regularType?: UnionType;
/* @internal */
origin?: Type; // Denormalized union, intersection, or index type in which union originates
/* @internal */
keyPropertyName?: __String; // Property with unique unit type that exists in every object/intersection in union type
/* @internal */
constituentMap?: ESMap<TypeId, Type>; // Constituents keyed by unit type discriminants
}

export interface IntersectionType extends UnionOrIntersectionType {
Expand Down
6 changes: 3 additions & 3 deletions tests/baselines/reference/arrayBestCommonTypes.types
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ module EmptyTypes {
>t1 : { x: number; y: base; }[]
>x : number
>y : base
>[{ x: 7, y: new derived() }, { x: 5, y: new base() }] : { x: number; y: derived; }[]
>[{ x: 7, y: new derived() }, { x: 5, y: new base() }] : ({ x: number; y: derived; } | { x: number; y: base; })[]
>{ x: 7, y: new derived() } : { x: number; y: derived; }
>x : number
>7 : 7
Expand Down Expand Up @@ -267,7 +267,7 @@ module EmptyTypes {
>t3 : { x: string; y: base; }[]
>x : string
>y : base
>[{ x: undefined, y: new base() }, { x: '', y: new derived() }] : { x: string; y: derived; }[]
>[{ x: undefined, y: new base() }, { x: '', y: new derived() }] : ({ x: undefined; y: base; } | { x: string; y: derived; })[]
>{ x: undefined, y: new base() } : { x: undefined; y: base; }
>x : undefined
>undefined : undefined
Expand Down Expand Up @@ -627,7 +627,7 @@ module NonEmptyTypes {
>t1 : { x: number; y: base; }[]
>x : number
>y : base
>[{ x: 7, y: new derived() }, { x: 5, y: new base() }] : { x: number; y: base; }[]
>[{ x: 7, y: new derived() }, { x: 5, y: new base() }] : ({ x: number; y: derived; } | { x: number; y: base; })[]
>{ x: 7, y: new derived() } : { x: number; y: derived; }
>x : number
>7 : 7
Expand Down
4 changes: 2 additions & 2 deletions tests/baselines/reference/arrayLiteralTypeInference.types
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ var x2: Action[] = [

var x3: Action[] = [
>x3 : Action[]
>[ new Action(), new ActionA(), new ActionB()] : Action[]
>[ new Action(), new ActionA(), new ActionB()] : (Action | ActionA | ActionB)[]

new Action(),
>new Action() : Action
Expand Down Expand Up @@ -119,7 +119,7 @@ var z3: { id: number }[] =
>id : number

[
>[ new Action(), new ActionA(), new ActionB() ] : Action[]
>[ new Action(), new ActionA(), new ActionB() ] : (Action | ActionA | ActionB)[]

new Action(),
>new Action() : Action
Expand Down
Loading