diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index a049b545cf..6779e8437a 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -937,40 +937,45 @@ func CheckIntersectionType( // The intersections may not have clashing members - // TODO: also include interface conformances' members - // once interfaces can have conformances + checkClashingMember := func(interfaceType *InterfaceType) { + interfaceType.Members.Foreach(func(name string, member *Member) { - interfaceType.Members.Foreach(func(name string, member *Member) { - if previousDeclaringInterfaceType, ok := memberSet[name]; ok { + if previousDeclaringInterfaceType, ok := memberSet[name]; ok { - // If there is an overlap in members, ensure the members have the same type + // If there is an overlap in members, ensure the members have the same type - memberType := member.TypeAnnotation.Type + memberType := member.TypeAnnotation.Type - prevMemberType, ok := previousDeclaringInterfaceType.Members.Get(name) - if !ok { - panic(errors.NewUnreachableError()) - } + prevMemberType, ok := previousDeclaringInterfaceType.Members.Get(name) + if !ok { + panic(errors.NewUnreachableError()) + } - previousMemberType := prevMemberType.TypeAnnotation.Type + previousMemberType := prevMemberType.TypeAnnotation.Type - if !memberType.IsInvalidType() && - !previousMemberType.IsInvalidType() && - !memberType.Equal(previousMemberType) { + if !memberType.IsInvalidType() && + !previousMemberType.IsInvalidType() && + !memberType.Equal(previousMemberType) { - report(func(t *ast.IntersectionType) error { - return &IntersectionMemberClashError{ - Name: name, - RedeclaringType: interfaceType, - OriginalDeclaringType: previousDeclaringInterfaceType, - Range: ast.NewRangeFromPositioned(memoryGauge, t.Types[i]), - } - }) + report(func(t *ast.IntersectionType) error { + return &IntersectionMemberClashError{ + Name: name, + RedeclaringType: interfaceType, + OriginalDeclaringType: previousDeclaringInterfaceType, + Range: ast.NewRangeFromPositioned(memoryGauge, t.Types[i]), + } + }) + } + } else { + memberSet[name] = interfaceType } - } else { - memberSet[name] = interfaceType - } - }) + }) + } + + checkClashingMember(interfaceType) + + interfaceType.EffectiveInterfaceConformanceSet(). + ForEach(checkClashingMember) } // If no intersection type is given, infer `AnyResource`/`AnyStruct` diff --git a/runtime/tests/checker/intersection_test.go b/runtime/tests/checker/intersection_test.go index d3395b740c..aaf602fd35 100644 --- a/runtime/tests/checker/intersection_test.go +++ b/runtime/tests/checker/intersection_test.go @@ -19,6 +19,7 @@ package checker import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -442,6 +443,97 @@ func TestCheckIntersectionTypeMemberAccess(t *testing.T) { }) } +func TestCheckIntersectionTypeWithInheritanceMemberClash(t *testing.T) { + + t.Parallel() + + const firstMember = `let n: Int` + const secondMember = `let n: Bool` + + test := func( + memberInA, memberInC bool, + firstType, secondType string, + ) { + + testName := fmt.Sprintf( + "memberInA: %v, memberInC: %v, firstType: %s, secondType: %s", + memberInA, memberInC, firstType, secondType, + ) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + bodyA := "" + bodyB := "" + bodyC := "" + bodyD := "" + + if memberInA { + bodyA = firstMember + } else { + bodyB = firstMember + } + + if memberInC { + bodyC = secondMember + } else { + bodyD = secondMember + } + + _, err := ParseAndCheck(t, + fmt.Sprintf( + ` + struct interface A { + %s + } + + struct interface B: A { + %s + } + + struct interface C { + %s + } + + struct interface D: C { + %s + } + + fun test(_ v: {%s, %s}) {} + `, + bodyA, + bodyB, + bodyC, + bodyD, + firstType, + secondType, + ), + ) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.IntersectionMemberClashError{}, errs[0]) + }) + } + + for _, memberInA := range []bool{true, false} { + for _, memberInC := range []bool{true, false} { + for _, firstType := range []string{"A", "B"} { + for _, secondType := range []string{"C", "D"} { + + if (firstType == "A" && !memberInA) || + (secondType == "C" && !memberInC) { + + continue + } + + test(memberInA, memberInC, firstType, secondType) + } + } + } + } +} + func TestCheckIntersectionTypeSubtyping(t *testing.T) { t.Parallel()