Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve union origin preservation in filtering-unionizing binary expressions #61362

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17409,7 +17409,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
return false;
}

function addTypeToUnion(typeSet: Type[], includes: TypeFlags, type: Type) {
function addTypeToUnion(typeSet: Type[] | undefined, includes: TypeFlags, type: Type) {
const flags = type.flags;
// We ignore 'never' types in unions
if (!(flags & TypeFlags.Never)) {
Expand All @@ -17421,7 +17421,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
if (!strictNullChecks && flags & TypeFlags.Nullable) {
if (!(getObjectFlags(type) & ObjectFlags.ContainsWideningType)) includes |= TypeFlags.IncludesNonWideningType;
}
else {
else if (typeSet) {
const len = typeSet.length;
const index = len && type.id > typeSet[len - 1].id ? ~len : binarySearch(typeSet, type, getTypeId, compareValues);
if (index < 0) {
Expand All @@ -17434,7 +17434,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {

// Add the given types to the given type set. Order is preserved, duplicates are removed,
// and nested types of the given kind are flattened into the set.
function addTypesToUnion(typeSet: Type[], includes: TypeFlags, types: readonly Type[]): TypeFlags {
function addTypesToUnion(typeSet: Type[] | undefined, includes: TypeFlags, types: readonly Type[]): TypeFlags {
let lastType: Type | undefined;
for (const type of types) {
// We skip the type if it is the same as the last type we processed. This simple test particularly
Expand Down Expand Up @@ -19644,6 +19644,10 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
return !!(type.flags & TypeFlags.Freshable) && (type as LiteralType).freshType === type;
}

function isRegularLiteralType(type: Type) {
return !!(type.flags & TypeFlags.Freshable) && (type as LiteralType).regularType === type;
}

function getStringLiteralType(value: string): StringLiteralType {
let type;
return stringLiteralTypes.get(value) ||
Expand Down Expand Up @@ -25036,10 +25040,6 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
return filterType(type, t => hasTypeFacts(t, TypeFacts.Truthy));
}

function extractDefinitelyFalsyTypes(type: Type): Type {
return mapType(type, getDefinitelyFalsyPartOfType);
}

function getDefinitelyFalsyPartOfType(type: Type): Type {
return type.flags & TypeFlags.String ? emptyStringType :
type.flags & TypeFlags.Number ? zeroType :
Expand Down Expand Up @@ -40160,7 +40160,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
case SyntaxKind.AmpersandAmpersandToken:
case SyntaxKind.AmpersandAmpersandEqualsToken: {
const resultType = hasTypeFacts(leftType, TypeFacts.Truthy) ?
getUnionType([extractDefinitelyFalsyTypes(strictNullChecks ? leftType : getBaseTypeOfLiteralType(rightType)), rightType]) :
getUnionOfLeftAndRightTypes(strictNullChecks ? leftType : getBaseTypeOfLiteralType(rightType), rightType, getDefinitelyFalsyPartOfType) :
leftType;
if (operator === SyntaxKind.AmpersandAmpersandEqualsToken) {
checkAssignmentOperator(rightType);
Expand All @@ -40170,7 +40170,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
case SyntaxKind.BarBarToken:
case SyntaxKind.BarBarEqualsToken: {
const resultType = hasTypeFacts(leftType, TypeFacts.Falsy) ?
getUnionType([getNonNullableType(removeDefinitelyFalsyTypes(leftType)), rightType], UnionReduction.Subtype) :
getUnionOfLeftAndRightTypes(leftType, rightType, t => hasTypeFacts(t, TypeFacts.Truthy) ? getNonNullableType(t) : neverType, UnionReduction.Subtype) :
leftType;
if (operator === SyntaxKind.BarBarEqualsToken) {
checkAssignmentOperator(rightType);
Expand All @@ -40180,7 +40180,7 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
case SyntaxKind.QuestionQuestionToken:
case SyntaxKind.QuestionQuestionEqualsToken: {
const resultType = hasTypeFacts(leftType, TypeFacts.EQUndefinedOrNull) ?
getUnionType([getNonNullableType(leftType), rightType], UnionReduction.Subtype) :
getUnionOfLeftAndRightTypes(leftType, rightType, getNonNullableType, UnionReduction.Subtype) :
leftType;
if (operator === SyntaxKind.QuestionQuestionEqualsToken) {
checkAssignmentOperator(rightType);
Expand Down Expand Up @@ -40415,6 +40415,20 @@ export function createTypeChecker(host: TypeCheckerHost): TypeChecker {
}
return false;
}

function getUnionOfLeftAndRightTypes(leftType: Type, rightType: Type, adjustLeft: (type: Type) => Type, unionReduction?: UnionReduction) {
const rightTypes = rightType.flags & TypeFlags.Union ? (rightType as UnionType).types : [rightType];
const includes = addTypesToUnion(/*typeSet*/ undefined, 0 as TypeFlags, rightTypes) & (TypeFlags.BaseOfLiteral | TypeFlags.Nullable);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not like all 3 cases require both BaseOfLiteral and Nullable preservation but there is no harm in treating them the same way here

return getUnionType([
mapType(
leftType,
// when something could be removed from the left type and when it's in the right type it means it would be re-added right away
// in such a case it's preserved in the mapped left type to help with origin/alias preservation
t => includes & t.flags || isRegularLiteralType(t) && containsType(rightTypes, (t as LiteralType).freshType) ? t : adjustLeft(t),
),
rightType,
], unionReduction);
}
}

function getBaseTypesIfUnrelated(leftType: Type, rightType: Type, isRelated: (left: Type, right: Type) => boolean): [Type, Type] {
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6306,6 +6306,8 @@ export const enum TypeFlags {
/** @internal */
Nullable = Undefined | Null,
Literal = StringLiteral | NumberLiteral | BigIntLiteral | BooleanLiteral,
/** @internal */
BaseOfLiteral = String | Number | BigInt | Boolean,
Unit = Enum | Literal | UniqueESSymbol | Nullable,
Freshable = Enum | Literal,
StringOrNumberLiteral = StringLiteral | NumberLiteral,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//// [tests/cases/compiler/unionBinaryExpressionPreserveOrigin1.ts] ////

=== unionBinaryExpressionPreserveOrigin1.ts ===
// https://github.com/microsoft/TypeScript/issues/43031

type Brand<K, T> = K & { __brand: T };
>Brand : Symbol(Brand, Decl(unionBinaryExpressionPreserveOrigin1.ts, 0, 0))
>K : Symbol(K, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 11))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 13))
>K : Symbol(K, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 11))
>__brand : Symbol(__brand, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 24))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 13))

type BrandedUnknown<T> = Brand<"unknown", T>;
>BrandedUnknown : Symbol(BrandedUnknown, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 38))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 3, 20))
>Brand : Symbol(Brand, Decl(unionBinaryExpressionPreserveOrigin1.ts, 0, 0))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 3, 20))

type Maybe<T> = T | BrandedUnknown<T>;
>Maybe : Symbol(Maybe, Decl(unionBinaryExpressionPreserveOrigin1.ts, 3, 45))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 4, 11))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 4, 11))
>BrandedUnknown : Symbol(BrandedUnknown, Decl(unionBinaryExpressionPreserveOrigin1.ts, 2, 38))
>T : Symbol(T, Decl(unionBinaryExpressionPreserveOrigin1.ts, 4, 11))

declare const m1: Maybe<boolean> | undefined;
>m1 : Symbol(m1, Decl(unionBinaryExpressionPreserveOrigin1.ts, 6, 13))
>Maybe : Symbol(Maybe, Decl(unionBinaryExpressionPreserveOrigin1.ts, 3, 45))

const test1 = m1 || false;
>test1 : Symbol(test1, Decl(unionBinaryExpressionPreserveOrigin1.ts, 7, 5))
>m1 : Symbol(m1, Decl(unionBinaryExpressionPreserveOrigin1.ts, 6, 13))

const test2 = m1 ?? false;
>test2 : Symbol(test2, Decl(unionBinaryExpressionPreserveOrigin1.ts, 8, 5))
>m1 : Symbol(m1, Decl(unionBinaryExpressionPreserveOrigin1.ts, 6, 13))

declare const m2: Maybe<null> | undefined;
>m2 : Symbol(m2, Decl(unionBinaryExpressionPreserveOrigin1.ts, 10, 13))
>Maybe : Symbol(Maybe, Decl(unionBinaryExpressionPreserveOrigin1.ts, 3, 45))

const test3 = m2 || null;
>test3 : Symbol(test3, Decl(unionBinaryExpressionPreserveOrigin1.ts, 11, 5))
>m2 : Symbol(m2, Decl(unionBinaryExpressionPreserveOrigin1.ts, 10, 13))

const test4 = m2 ?? null;
>test4 : Symbol(test4, Decl(unionBinaryExpressionPreserveOrigin1.ts, 12, 5))
>m2 : Symbol(m2, Decl(unionBinaryExpressionPreserveOrigin1.ts, 10, 13))

type StrOrNum = string | number
>StrOrNum : Symbol(StrOrNum, Decl(unionBinaryExpressionPreserveOrigin1.ts, 12, 25))

declare const numOrStr: StrOrNum;
>numOrStr : Symbol(numOrStr, Decl(unionBinaryExpressionPreserveOrigin1.ts, 15, 13))
>StrOrNum : Symbol(StrOrNum, Decl(unionBinaryExpressionPreserveOrigin1.ts, 12, 25))

const test5 = numOrStr && numOrStr;
>test5 : Symbol(test5, Decl(unionBinaryExpressionPreserveOrigin1.ts, 16, 5))
>numOrStr : Symbol(numOrStr, Decl(unionBinaryExpressionPreserveOrigin1.ts, 15, 13))
>numOrStr : Symbol(numOrStr, Decl(unionBinaryExpressionPreserveOrigin1.ts, 15, 13))

Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//// [tests/cases/compiler/unionBinaryExpressionPreserveOrigin1.ts] ////

=== unionBinaryExpressionPreserveOrigin1.ts ===
// https://github.com/microsoft/TypeScript/issues/43031

type Brand<K, T> = K & { __brand: T };
>Brand : Brand<K, T>
> : ^^^^^^^^^^^
>__brand : T
> : ^

type BrandedUnknown<T> = Brand<"unknown", T>;
>BrandedUnknown : BrandedUnknown<T>
> : ^^^^^^^^^^^^^^^^^

type Maybe<T> = T | BrandedUnknown<T>;
>Maybe : Maybe<T>
> : ^^^^^^^^

declare const m1: Maybe<boolean> | undefined;
>m1 : Maybe<boolean> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^^^^

const test1 = m1 || false;
>test1 : Maybe<boolean>
> : ^^^^^^^^^^^^^^
>m1 || false : Maybe<boolean>
> : ^^^^^^^^^^^^^^
>m1 : Maybe<boolean> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^^^^
>false : false
> : ^^^^^

const test2 = m1 ?? false;
>test2 : Maybe<boolean>
> : ^^^^^^^^^^^^^^
>m1 ?? false : Maybe<boolean>
> : ^^^^^^^^^^^^^^
>m1 : Maybe<boolean> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^^^^
>false : false
> : ^^^^^

declare const m2: Maybe<null> | undefined;
>m2 : Maybe<null> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^

const test3 = m2 || null;
>test3 : Maybe<null>
> : ^^^^^^^^^^^
>m2 || null : Maybe<null>
> : ^^^^^^^^^^^
>m2 : Maybe<null> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^

const test4 = m2 ?? null;
>test4 : Maybe<null>
> : ^^^^^^^^^^^
>m2 ?? null : Maybe<null>
> : ^^^^^^^^^^^
>m2 : Maybe<null> | undefined
> : ^^^^^^^^^^^^^^^^^^^^^^^

type StrOrNum = string | number
>StrOrNum : StrOrNum
> : ^^^^^^^^

declare const numOrStr: StrOrNum;
>numOrStr : StrOrNum
> : ^^^^^^^^

const test5 = numOrStr && numOrStr;
>test5 : StrOrNum
> : ^^^^^^^^
>numOrStr && numOrStr : StrOrNum
> : ^^^^^^^^
>numOrStr : StrOrNum
> : ^^^^^^^^
>numOrStr : StrOrNum
> : ^^^^^^^^

20 changes: 20 additions & 0 deletions tests/cases/compiler/unionBinaryExpressionPreserveOrigin1.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// @strict: true
// @noEmit: true

// https://github.com/microsoft/TypeScript/issues/43031

type Brand<K, T> = K & { __brand: T };
type BrandedUnknown<T> = Brand<"unknown", T>;
type Maybe<T> = T | BrandedUnknown<T>;

declare const m1: Maybe<boolean> | undefined;
const test1 = m1 || false;
const test2 = m1 ?? false;

declare const m2: Maybe<null> | undefined;
const test3 = m2 || null;
const test4 = m2 ?? null;

type StrOrNum = string | number
declare const numOrStr: StrOrNum;
const test5 = numOrStr && numOrStr;