Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: simplify computed types #866

Merged
merged 10 commits into from
Feb 7, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ export class FloatConstant extends NumberConstant {
}

override toString(): string {
return this.value.toString();
const string = this.value.toString();

if (!string.includes('.')) {
return `${string}.0`;
} else {
return string;
}
}
}

Expand Down
33 changes: 30 additions & 3 deletions packages/safe-ds-lang/src/language/typing/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ export class ClassType extends NamedType<SdsClass> {
}

override updateNullability(isNullable: boolean): ClassType {
if (this.isNullable === isNullable) {
return this;
}

return new ClassType(this.declaration, this.substitutions, isNullable);
}
}
Expand Down Expand Up @@ -384,6 +388,10 @@ export class EnumType extends NamedType<SdsEnum> {
}

override updateNullability(isNullable: boolean): EnumType {
if (this.isNullable === isNullable) {
return this;
}

return new EnumType(this.declaration, isNullable);
}
}
Expand Down Expand Up @@ -411,6 +419,10 @@ export class EnumVariantType extends NamedType<SdsEnumVariant> {
}

override updateNullability(isNullable: boolean): EnumVariantType {
if (this.isNullable === isNullable) {
return this;
}

return new EnumVariantType(this.declaration, isNullable);
}
}
Expand Down Expand Up @@ -446,6 +458,10 @@ export class TypeParameterType extends NamedType<SdsTypeParameter> {
}

override updateNullability(isNullable: boolean): TypeParameterType {
if (this.isNullable === isNullable) {
return this;
}

return new TypeParameterType(this.declaration, isNullable);
}
}
Expand Down Expand Up @@ -530,11 +546,22 @@ export class UnionType extends Type {
}

override unwrap(): Type {
if (this.possibleTypes.length === 1) {
return this.possibleTypes[0]!.unwrap();
// Flatten nested unions
const newPossibleTypes = this.possibleTypes.flatMap((type) => {
const unwrappedType = type.unwrap();
if (unwrappedType instanceof UnionType) {
return unwrappedType.possibleTypes;
} else {
return unwrappedType;
}
});

// Remove the outer union if there's only one type left
if (newPossibleTypes.length === 1) {
return newPossibleTypes[0]!;
}

return new UnionType(...this.possibleTypes.map((it) => it.unwrap()));
return new UnionType(...newPossibleTypes);
}

override updateNullability(isNullable: boolean): Type {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
Parameter,
TypeParameter,
} from '../helpers/nodeProperties.js';
import { Constant } from '../partialEvaluation/model.js';
import { Constant, NullConstant } from '../partialEvaluation/model.js';
import { SafeDsServices } from '../safe-ds-module.js';
import {
CallableType,
Expand Down Expand Up @@ -222,10 +222,16 @@ export class SafeDsTypeChecker {
private literalTypeIsAssignableTo(type: LiteralType, other: Type): boolean {
if (type.isNullable && !other.isNullable) {
return false;
} else if (type.constants.length === 0) {
// Empty literal types are equivalent to `Nothing` and assignable to any type
return true;
} else if (type.constants.every((it) => it === NullConstant)) {
// Literal types containing only `null` are equivalent to `Nothing?` and assignable to any nullable type
return other.isNullable;
}

if (other instanceof ClassType) {
if (other.equals(this.coreTypes.AnyOrNull)) {
if (other.equals(this.coreTypes.Any.updateNullability(type.isNullable))) {
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ export class SafeDsTypeComputer {
return UnknownType;
}

const unsubstitutedType = this.nodeTypeCache.get(this.getNodeId(node), () => this.doComputeType(node).unwrap());
const unsubstitutedType = this.nodeTypeCache.get(this.getNodeId(node), () =>
this.simplifyType(this.doComputeType(node)),
);
return unsubstitutedType.substituteTypeParameters(substitutions);
}

Expand Down Expand Up @@ -609,6 +611,88 @@ export class SafeDsTypeComputer {
return result;
}

// -----------------------------------------------------------------------------------------------------------------
// Simplify type
// -----------------------------------------------------------------------------------------------------------------

private simplifyType(type: Type): Type {
const unwrappedType = type.unwrap();

if (unwrappedType instanceof LiteralType) {
return this.simplifyLiteralType(unwrappedType);
} else if (unwrappedType instanceof UnionType) {
return this.simplifyUnionType(unwrappedType);
} else {
return unwrappedType;
}
}

private simplifyLiteralType(type: LiteralType): Type {
// Handle empty literal types
if (isEmpty(type.constants)) {
return this.coreTypes.Nothing;
}

// Remove duplicate constants
const uniqueConstants: Constant[] = [];
const knownConstants = new Set<String>();

for (const constant of type.constants) {
let key = constant.toString();

if (!knownConstants.has(key)) {
uniqueConstants.push(constant);
knownConstants.add(key);
}
}

// Apply other simplifications
if (uniqueConstants.length === 1 && uniqueConstants[0] === NullConstant) {
return this.coreTypes.NothingOrNull;
} else if (uniqueConstants.length < type.constants.length) {
return new LiteralType(...uniqueConstants);
} else {
return type;
}
}

private simplifyUnionType(type: UnionType): Type {
// Handle empty union types
if (isEmpty(type.possibleTypes)) {
return this.coreTypes.Nothing;
}

// Simplify possible types
const newPossibleTypes = type.possibleTypes.map((it) => this.simplifyType(it));

// Remove types that are subtypes of others. We do this back-to-front to keep the first occurrence of duplicate
// types. It's also makes splicing easier.
for (let i = newPossibleTypes.length - 1; i >= 0; i--) {
const currentType = newPossibleTypes[i]!;

for (let j = 0; j < newPossibleTypes.length; j++) {
if (i === j) {
continue;
}

let otherType = newPossibleTypes[j]!;
otherType = otherType.updateNullability(currentType.isNullable || otherType.isNullable);

if (this.typeChecker.isAssignableTo(currentType, otherType)) {
newPossibleTypes.splice(j, 1, otherType); // Update nullability
newPossibleTypes.splice(i, 1);
break;
}
}
}

if (newPossibleTypes.length === 1) {
return newPossibleTypes[0]!;
} else {
return new UnionType(...newPossibleTypes);
}
}

// -----------------------------------------------------------------------------------------------------------------
// Compute class types for literal types and their constants
// -----------------------------------------------------------------------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions packages/safe-ds-lang/tests/language/typing/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,18 @@ describe('type model', async () => {
type: new UnionType(new UnionType(new ClassType(class1, new Map(), false))),
expectedType: new ClassType(class1, new Map(), false),
},
{
type: new UnionType(
new UnionType(new ClassType(class1, new Map(), false), new ClassType(class2, new Map(), false)),
new UnionType(new EnumType(enum1, false), new EnumVariantType(enumVariant1, false)),
),
expectedType: new UnionType(
new ClassType(class1, new Map(), false),
new ClassType(class2, new Map(), false),
new EnumType(enum1, false),
new EnumVariantType(enumVariant1, false),
),
},
{
type: UnknownType,
expectedType: UnknownType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,26 @@ describe('SafeDsTypeChecker', async () => {
expected: false,
},
// Literal type to other
{
type1: new LiteralType(), // Empty literal type
type2: enumType1,
expected: true,
},
{
type1: new LiteralType(NullConstant),
type2: enumType1,
expected: false,
},
{
type1: new LiteralType(NullConstant),
type2: enumType1.updateNullability(true),
expected: true,
},
{
type1: new LiteralType(NullConstant, NullConstant),
type2: enumType1.updateNullability(true),
expected: true,
},
{
type1: new LiteralType(new IntConstant(1n)),
type2: enumType1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import {
isSdsFunction,
isSdsModule,
} from '../../../../src/language/generated/ast.js';
import { getModuleMembers } from '../../../../src/language/helpers/nodeProperties.js';
import { createSafeDsServicesWithBuiltins } from '../../../../src/language/index.js';
import { createSafeDsServicesWithBuiltins, getModuleMembers } from '../../../../src/language/index.js';
import { BooleanConstant, IntConstant, NullConstant } from '../../../../src/language/partialEvaluation/model.js';
import {
ClassType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pipeline myPipeline {
// $TEST$ serialization literal<1>
val intLiteral = »1«;

// $TEST$ serialization literal<null>
// $TEST$ serialization Nothing?
val nullLiteral = »null«;

// $TEST$ serialization literal<"myString">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ segment mySegment1(p: C<Int>) {
»nullableC()?.nonNullableMember«;
// $TEST$ serialization Int?
»nullableC()?.nullableMember«;
// $TEST$ serialization union<() -> (r: Int), literal<null>>
// $TEST$ serialization union<() -> (r: Int), Nothing?>
»nullableC()?.method«;
}

Expand Down Expand Up @@ -71,6 +71,6 @@ segment mySegment2(p: D) {
»nullableD()?.nonNullableMember«;
// $TEST$ serialization Int?
»nullableD()?.nullableMember«;
// $TEST$ serialization union<() -> (r: Int), literal<null>>
// $TEST$ serialization union<() -> (r: Int), Nothing?>
»nullableD()?.method«;
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ pipeline myPipeline {
»nullableC()?.nonNullableMember«;
// $TEST$ equivalence_class nullableMember
»nullableC()?.nullableMember«;
// $TEST$ serialization union<() -> (r: Int), literal<null>>
// $TEST$ serialization union<() -> (r: Int), Nothing?>
»nullableC()?.method«;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package tests.typing.simplification.removeDuplicateConstantsFromLiteralTypes

class C(
// $TEST$ serialization literal<1>
p1: »literal<1, 1>«,

// $TEST$ serialization literal<1, 2>
p2: »literal<1, 2>«,

// $TEST$ serialization literal<1, 1.0>
p3: »literal<1, 1.0>«,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package tests.typing.simplification.removeUnneededEntriesFromUnionTypes

class C(
// $TEST$ serialization Int
p1: »union<Int, Int>«,

// $TEST$ serialization union<Int, String>
p2: »union<Int, String, Int>«,


// $TEST$ serialization Number
p3: »union<Int, Number>«,

// $TEST$ serialization Number
p4: »union<Number, Int>«,

// $TEST$ serialization Number?
p5: »union<Number, Int?>«,

// $TEST$ serialization Any
p6: »union<Int, Number, Any>«,

// $TEST$ serialization Any
p7: »union<Any, Number, Int>«,

// $TEST$ serialization Any?
p8: »union<Int, Number?, Any>«,


// $TEST$ serialization union<Int, String>
p9: »union<Int, String>«,

// $TEST$ serialization union<Int, String?>
p10: »union<Int, String?>«,

// $TEST$ serialization union<Number, String>
p11: »union<Int, Number, String>«,

// $TEST$ serialization union<Number, String>
p12: »union<Number, Int, String>«,

// $TEST$ serialization Any
p13: »union<Int, String, Any>«,

// $TEST$ serialization Any?
p14: »union<Int, String?, Any>«,

// $TEST$ serialization Any
p15: »union<Any, String, Int>«,

// $TEST$ serialization Any?
p16: »union<Any, String?, Int>«,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package tests.typing.simplification.replaceEmptyLiteralTypesWithNothing

class C(
// $TEST$ serialization Nothing
p1: »literal<>«
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package tests.typing.simplification.replaceEmptyUnionTypesWithNothing

class C(
// $TEST$ serialization Nothing
p1: »union<>«
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package tests.typing.simplification.replaceLiteralTypesThatAllowOnlyNullWithNothingNullable

class C(
// $TEST$ serialization Nothing?
p1: »literal<null>«,

// $TEST$ serialization Nothing?
p2: »literal<null, null>«
)
Loading