diff --git a/migrations/statictypes/account_type_migration_test.go b/migrations/statictypes/account_type_migration_test.go index c166fcdba5..7441da769b 100644 --- a/migrations/statictypes/account_type_migration_test.go +++ b/migrations/statictypes/account_type_migration_test.go @@ -314,7 +314,7 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) { interpreter.NewVariableSizedStaticType(nil, authAccountReferenceType), ), }, - "interface": { + "non_intersection_interface": { storedType: interpreter.NewInterfaceStaticType( nil, nil, @@ -325,6 +325,38 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) { fooBarQualifiedIdentifier, ), ), + expectedType: interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interpreter.NewInterfaceStaticType( + nil, + nil, + fooBarQualifiedIdentifier, + common.NewTypeIDFromQualifiedName( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + ), + ), + }, + ), + }, + "intersection_interface": { + storedType: interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interpreter.NewInterfaceStaticType( + nil, + nil, + fooBarQualifiedIdentifier, + common.NewTypeIDFromQualifiedName( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + ), + ), + }, + ), expectedType: nil, }, "composite": { @@ -437,7 +469,23 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) { typeValue := value.(interpreter.TypeValue) if actualIntersectionType, ok := typeValue.Type.(*interpreter.IntersectionStaticType); ok { expectedIntersectionType := testCase.expectedType.(*interpreter.IntersectionStaticType) - assert.True(t, actualIntersectionType.LegacyType.Equal(expectedIntersectionType.LegacyType)) + + if actualIntersectionType.LegacyType != nil { + assert.True(t, + actualIntersectionType.LegacyType. + Equal(expectedIntersectionType.LegacyType), + ) + } else if expectedIntersectionType.LegacyType != nil { + assert.True(t, + expectedIntersectionType.LegacyType. + Equal(actualIntersectionType.LegacyType), + ) + } else { + assert.Equal(t, + expectedIntersectionType.LegacyType, + actualIntersectionType.LegacyType, + ) + } } } else { expectedValue = interpreter.NewTypeValue(nil, testCase.storedType) diff --git a/migrations/statictypes/composite_type_migration_test.go b/migrations/statictypes/composite_type_migration_test.go index 61d0794f5e..6df0ed7b69 100644 --- a/migrations/statictypes/composite_type_migration_test.go +++ b/migrations/statictypes/composite_type_migration_test.go @@ -43,7 +43,7 @@ func TestCompositeAndInterfaceTypeMigration(t *testing.T) { expectedType interpreter.StaticType } - newCompositeType := func() interpreter.StaticType { + newCompositeType := func() *interpreter.CompositeStaticType { return interpreter.NewCompositeStaticType( nil, nil, @@ -56,7 +56,7 @@ func TestCompositeAndInterfaceTypeMigration(t *testing.T) { ) } - newInterfaceType := func() interpreter.StaticType { + newInterfaceType := func() *interpreter.InterfaceStaticType { return interpreter.NewInterfaceStaticType( nil, nil, @@ -72,8 +72,13 @@ func TestCompositeAndInterfaceTypeMigration(t *testing.T) { testCases := map[string]testCase{ // base cases "compositeToInterface": { - storedType: newCompositeType(), - expectedType: newInterfaceType(), + storedType: newCompositeType(), + expectedType: interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + newInterfaceType(), + }, + ), }, "interfaceToComposite": { storedType: newInterfaceType(), diff --git a/migrations/statictypes/intersection_type_migration_test.go b/migrations/statictypes/intersection_type_migration_test.go index 26afa9f343..eec47b8792 100644 --- a/migrations/statictypes/intersection_type_migration_test.go +++ b/migrations/statictypes/intersection_type_migration_test.go @@ -332,7 +332,7 @@ func TestIntersectionTypeMigration(t *testing.T) { ), }, // interface - "interface": { + "non_intersection_interface": { storedType: interpreter.NewInterfaceStaticType( nil, nil, @@ -343,6 +343,39 @@ func TestIntersectionTypeMigration(t *testing.T) { fooBarQualifiedIdentifier, ), ), + expectedType: interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interpreter.NewInterfaceStaticType( + nil, + nil, + "Foo.Bar", + common.NewTypeIDFromQualifiedName( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + ), + ), + }, + ), + }, + "intersection_interface": { + storedType: interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + interpreter.NewInterfaceStaticType( + nil, + nil, + "Foo.Bar", + common.NewTypeIDFromQualifiedName( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + ), + ), + }, + ), + expectedType: nil, }, // composite "composite": { @@ -356,6 +389,7 @@ func TestIntersectionTypeMigration(t *testing.T) { fooBarQualifiedIdentifier, ), ), + expectedType: nil, }, } @@ -1187,3 +1221,272 @@ func TestIntersectionTypeMigrationWithInterfaceTypeConverter(t *testing.T) { } } } + +func TestIntersectionTypeMigrationWithTypeConverters(t *testing.T) { + t.Parallel() + + migrate := func( + t *testing.T, + staticTypeMigration *StaticTypeMigration, + input interpreter.StaticType, + ) interpreter.StaticType { + + // Store values + + ledger := NewTestLedger(nil, nil) + storage := runtime.NewStorage(ledger, nil) + + inter, err := interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + AtreeValueValidationEnabled: false, + AtreeStorageValidationEnabled: true, + }, + ) + require.NoError(t, err) + + const testPathDomain = common.PathDomainStorage + const testPathIdentifier = "test_type_value" + + storeTypeValue( + inter, + testAddress, + testPathDomain, + testPathIdentifier, + input, + ) + + err = storage.Commit(inter, true) + require.NoError(t, err) + + // Migrate + + migration := migrations.NewStorageMigration(inter, storage) + + reporter := newTestReporter() + + migration.Migrate( + &migrations.AddressSliceIterator{ + Addresses: []common.Address{ + testAddress, + }, + }, + migration.NewValueMigrationsPathMigrator( + reporter, + staticTypeMigration, + ), + ) + + err = migration.Commit() + require.NoError(t, err) + + key := struct { + interpreter.StorageKey + interpreter.StorageMapKey + }{ + StorageKey: interpreter.StorageKey{ + Address: testAddress, + Key: testPathDomain.Identifier(), + }, + StorageMapKey: interpreter.StringStorageMapKey(testPathIdentifier), + } + + assert.Contains(t, reporter.migrated, key) + + storageMap := storage.GetStorageMap(testAddress, testPathDomain.Identifier(), false) + require.NotNil(t, storageMap) + require.Equal(t, uint64(1), storageMap.Count()) + + value := storageMap.ReadValue(nil, interpreter.StringStorageMapKey(testPathIdentifier)) + require.NotNil(t, value) + + require.IsType(t, interpreter.TypeValue{}, value) + + return value.(interpreter.TypeValue).Type + } + + const fooCompositeQualifiedIdentifierA = "Foo.A" + const fooCompositeQualifiedIdentifierB = "Foo.B" + + fooACompositeType := interpreter.NewCompositeStaticType( + nil, + fooAddressLocation, + fooCompositeQualifiedIdentifierA, + fooAddressLocation.TypeID(nil, fooCompositeQualifiedIdentifierA), + ) + + fooBCompositeType := interpreter.NewCompositeStaticType( + nil, + fooAddressLocation, + fooCompositeQualifiedIdentifierB, + fooAddressLocation.TypeID(nil, fooCompositeQualifiedIdentifierB), + ) + + const fooInterfaceQualifiedIdentifierC = "Foo.C" + const fooInterfaceQualifiedIdentifierD = "Foo.D" + + fooCInterfaceType := interpreter.NewInterfaceStaticType( + nil, + fooAddressLocation, + fooInterfaceQualifiedIdentifierC, + fooAddressLocation.TypeID(nil, fooInterfaceQualifiedIdentifierC), + ) + fooDInterfaceType := interpreter.NewInterfaceStaticType( + nil, + fooAddressLocation, + fooInterfaceQualifiedIdentifierD, + fooAddressLocation.TypeID(nil, fooInterfaceQualifiedIdentifierD), + ) + + t.Run("composite type converter", func(t *testing.T) { + t.Parallel() + + t.Run("return non-interface", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithCompositeTypeConverter(func(staticType *interpreter.CompositeStaticType) interpreter.StaticType { + if staticType == fooACompositeType { + return fooBCompositeType + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooACompositeType) + assert.Equal(t, fooBCompositeType, actual) + }) + + t.Run("return interface", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithCompositeTypeConverter(func(staticType *interpreter.CompositeStaticType) interpreter.StaticType { + if staticType == fooACompositeType { + // NOTE: return interface type as-is, not wrapped in intersection type, + // to test if it gets wrapped properly into an intersection type + return fooCInterfaceType + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooACompositeType) + assert.Equal(t, + interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooCInterfaceType, + }, + ), + actual, + ) + }) + + t.Run("return intersection", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithCompositeTypeConverter(func(staticType *interpreter.CompositeStaticType) interpreter.StaticType { + if staticType == fooACompositeType { + // NOTE: return interface type wrapped in intersection type, + // to test if it does not get re-wrapped into an intersection type + return interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooCInterfaceType, + }, + ) + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooACompositeType) + assert.Equal(t, + interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooCInterfaceType, + }, + ), + actual, + ) + + }) + }) + + t.Run("interface type converter", func(t *testing.T) { + t.Parallel() + + t.Run("return non-interface", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithInterfaceTypeConverter(func(staticType *interpreter.InterfaceStaticType) interpreter.StaticType { + if staticType == fooCInterfaceType { + return fooBCompositeType + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooCInterfaceType) + assert.Equal(t, fooBCompositeType, actual) + }) + + t.Run("return interface", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithInterfaceTypeConverter(func(staticType *interpreter.InterfaceStaticType) interpreter.StaticType { + if staticType == fooCInterfaceType { + // NOTE: return interface type as-is, not wrapped in intersection type, + // to test if it gets wrapped properly into an intersection type + return fooDInterfaceType + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooCInterfaceType) + assert.Equal(t, + interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooDInterfaceType, + }, + ), + actual, + ) + + }) + + t.Run("return intersection", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration(). + WithInterfaceTypeConverter(func(staticType *interpreter.InterfaceStaticType) interpreter.StaticType { + if staticType == fooCInterfaceType { + // NOTE: return interface type wrapped in intersection type, + // to test if it does not get re-wrapped into an intersection type + return interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooDInterfaceType, + }, + ) + } + return nil + }) + + actual := migrate(t, staticTypeMigration, fooCInterfaceType) + assert.Equal(t, + interpreter.NewIntersectionStaticType( + nil, + []*interpreter.InterfaceStaticType{ + fooDInterfaceType, + }, + ), + actual, + ) + }) + }) +} diff --git a/migrations/statictypes/statictype_migration.go b/migrations/statictypes/statictype_migration.go index 7aae9c28de..aa006b44be 100644 --- a/migrations/statictypes/statictype_migration.go +++ b/migrations/statictypes/statictype_migration.go @@ -33,8 +33,8 @@ type StaticTypeMigration struct { interfaceTypeConverter InterfaceTypeConverterFunc } -type CompositeTypeConverterFunc func(staticType *interpreter.CompositeStaticType) interpreter.StaticType -type InterfaceTypeConverterFunc func(staticType *interpreter.InterfaceStaticType) interpreter.StaticType +type CompositeTypeConverterFunc func(*interpreter.CompositeStaticType) interpreter.StaticType +type InterfaceTypeConverterFunc func(*interpreter.InterfaceStaticType) interpreter.StaticType var _ migrations.ValueMigration = &StaticTypeMigration{} @@ -66,21 +66,21 @@ func (m *StaticTypeMigration) Migrate( ) (newValue interpreter.Value, err error) { switch value := value.(type) { case interpreter.TypeValue: - convertedType := m.maybeConvertStaticType(value.Type) + convertedType := m.maybeConvertStaticType(value.Type, nil) if convertedType == nil { return } return interpreter.NewTypeValue(nil, convertedType), nil case *interpreter.IDCapabilityValue: - convertedBorrowType := m.maybeConvertStaticType(value.BorrowType) + convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { return } return interpreter.NewUnmeteredCapabilityValue(value.ID, value.Address, convertedBorrowType), nil case *interpreter.PathCapabilityValue: //nolint:staticcheck - convertedBorrowType := m.maybeConvertStaticType(value.BorrowType) + convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { return } @@ -91,7 +91,7 @@ func (m *StaticTypeMigration) Migrate( }, nil case interpreter.PathLinkValue: //nolint:staticcheck - convertedBorrowType := m.maybeConvertStaticType(value.Type) + convertedBorrowType := m.maybeConvertStaticType(value.Type, nil) if convertedBorrowType == nil { return } @@ -101,7 +101,7 @@ func (m *StaticTypeMigration) Migrate( }, nil case *interpreter.AccountCapabilityControllerValue: - convertedBorrowType := m.maybeConvertStaticType(value.BorrowType) + convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { return } @@ -109,7 +109,7 @@ func (m *StaticTypeMigration) Migrate( return interpreter.NewUnmeteredAccountCapabilityControllerValue(borrowType, value.CapabilityID), nil case *interpreter.StorageCapabilityControllerValue: - convertedBorrowType := m.maybeConvertStaticType(value.BorrowType) + convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil) if convertedBorrowType == nil { return } @@ -124,23 +124,23 @@ func (m *StaticTypeMigration) Migrate( return } -func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.StaticType) interpreter.StaticType { +func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType interpreter.StaticType) interpreter.StaticType { switch staticType := staticType.(type) { case *interpreter.ConstantSizedStaticType: - convertedType := m.maybeConvertStaticType(staticType.Type) + convertedType := m.maybeConvertStaticType(staticType.Type, staticType) if convertedType != nil { return interpreter.NewConstantSizedStaticType(nil, convertedType, staticType.Size) } case *interpreter.VariableSizedStaticType: - convertedType := m.maybeConvertStaticType(staticType.Type) + convertedType := m.maybeConvertStaticType(staticType.Type, staticType) if convertedType != nil { return interpreter.NewVariableSizedStaticType(nil, convertedType) } case *interpreter.DictionaryStaticType: - convertedKeyType := m.maybeConvertStaticType(staticType.KeyType) - convertedValueType := m.maybeConvertStaticType(staticType.ValueType) + convertedKeyType := m.maybeConvertStaticType(staticType.KeyType, staticType) + convertedValueType := m.maybeConvertStaticType(staticType.ValueType, staticType) if convertedKeyType != nil && convertedValueType != nil { return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType) } @@ -152,7 +152,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat } case *interpreter.CapabilityStaticType: - convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType) + convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType, staticType) if convertedBorrowType != nil { return interpreter.NewCapabilityStaticType(nil, convertedBorrowType) } @@ -164,7 +164,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat var convertedInterfaceType bool for _, interfaceStaticType := range staticType.Types { - convertedType := m.maybeConvertStaticType(interfaceStaticType) + convertedType := m.maybeConvertStaticType(interfaceStaticType, staticType) // lazily allocate the slice if convertedInterfaceTypes == nil { @@ -194,7 +194,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat legacyType := staticType.LegacyType var convertedLegacyType interpreter.StaticType if legacyType != nil { - convertedLegacyType = m.maybeConvertStaticType(legacyType) + convertedLegacyType = m.maybeConvertStaticType(legacyType, staticType) } // If the interface set has at least two items, @@ -214,14 +214,14 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat } case *interpreter.OptionalStaticType: - convertedInnerType := m.maybeConvertStaticType(staticType.Type) + convertedInnerType := m.maybeConvertStaticType(staticType.Type, staticType) if convertedInnerType != nil { return interpreter.NewOptionalStaticType(nil, convertedInnerType) } case *interpreter.ReferenceStaticType: // TODO: Reference of references must not be allowed? - convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType) + convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType, staticType) if convertedReferencedType != nil { switch convertedReferencedType { @@ -232,7 +232,11 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat return convertedReferencedType default: - return interpreter.NewReferenceStaticType(nil, staticType.Authorization, convertedReferencedType) + return interpreter.NewReferenceStaticType( + nil, + staticType.Authorization, + convertedReferencedType, + ) } } @@ -240,22 +244,71 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat // Non-storable case *interpreter.CompositeStaticType: + var convertedType interpreter.StaticType compositeTypeConverter := m.compositeTypeConverter if compositeTypeConverter != nil { - return compositeTypeConverter(staticType) + convertedType = compositeTypeConverter(staticType) } + // Interface types need to be placed in intersection types. + // If the composite type was converted to an interface type, + // and if the parent type is not an intersection type, + // then the converted interface type must be placed in an intersection type + if convertedInterfaceType, ok := convertedType.(*interpreter.InterfaceStaticType); ok { + if _, ok := parentType.(*interpreter.IntersectionStaticType); !ok { + convertedType = interpreter.NewIntersectionStaticType( + nil, []*interpreter.InterfaceStaticType{ + convertedInterfaceType, + }, + ) + } + } + + return convertedType + case *interpreter.InterfaceStaticType: + var convertedType interpreter.StaticType interfaceTypeConverter := m.interfaceTypeConverter if interfaceTypeConverter != nil { - return interfaceTypeConverter(staticType) + convertedType = interfaceTypeConverter(staticType) } + // Interface types need to be placed in intersection types + if _, ok := parentType.(*interpreter.IntersectionStaticType); !ok { + // If the interface type was not converted to another type, + // and given the parent type is not an intersection type, + // then the original interface type must be placed in an intersection type + if convertedType == nil { + convertedType = interpreter.NewIntersectionStaticType( + nil, []*interpreter.InterfaceStaticType{ + staticType, + }, + ) + } else { + // If the interface type was converted to another type, + // it may have been converted to + // - a different kind of type, e.g. a composite type, + // in which case the converted type should be returned as-is + // - another interface type – + // given the parent type is not an intersection type, + // then the converted interface type must be placed in an intersection type + if convertedInterfaceType, ok := convertedType.(*interpreter.InterfaceStaticType); ok { + convertedType = interpreter.NewIntersectionStaticType( + nil, []*interpreter.InterfaceStaticType{ + convertedInterfaceType, + }, + ) + } + } + } + + return convertedType + case dummyStaticType: // This is for testing the migration. - // i.e: wrapper was only to make it possible to use as a dictionary-key. + // i.e: the dummyStaticType wrapper was only introduced to make it possible to use the type as a dictionary key. // Ignore the wrapper, and continue with the inner type. - return m.maybeConvertStaticType(staticType.PrimitiveStaticType) + return m.maybeConvertStaticType(staticType.PrimitiveStaticType, staticType) case interpreter.PrimitiveStaticType: // Is it safe to do so? diff --git a/runtime/account_test.go b/runtime/account_test.go index 417fe775be..54d37cbf53 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -1656,7 +1656,7 @@ func TestRuntimeAuthAccountContracts(t *testing.T) { transaction { prepare(acc: &Account) { - let hello = acc.contracts.borrow<&HelloInterface>(name: "Hello") + let hello = acc.contracts.borrow<&{HelloInterface}>(name: "Hello") assert(hello?.hello() == "Hello!") } } @@ -1747,7 +1747,7 @@ func TestRuntimeAuthAccountContracts(t *testing.T) { transaction { prepare(acc: &Account) { - let hello = acc.contracts.borrow<&HelloInterface>(name: "Hello") + let hello = acc.contracts.borrow<&{HelloInterface}>(name: "Hello") assert(hello == nil) } } diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 47ef402f1d..35347c463f 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -3391,15 +3391,6 @@ func init() { ), ) - defineBaseValue( - BaseActivation, - sema.InterfaceTypeFunctionName, - NewUnmeteredHostFunctionValue( - sema.InterfaceTypeFunctionType, - interfaceTypeFunction, - ), - ) - defineBaseValue( BaseActivation, sema.FunctionTypeFunctionName, @@ -3536,27 +3527,6 @@ func compositeTypeFunction(invocation Invocation) Value { ) } -func interfaceTypeFunction(invocation Invocation) Value { - typeIDValue, ok := invocation.Arguments[0].(*StringValue) - if !ok { - panic(errors.NewUnreachableError()) - } - typeID := typeIDValue.Str - - interfaceType, err := lookupInterface(invocation.Interpreter, typeID) - if err != nil { - return Nil - } - - return NewSomeValueNonCopying( - invocation.Interpreter, - NewTypeValue( - invocation.Interpreter, - ConvertSemaToStaticType(invocation.Interpreter, interfaceType), - ), - ) -} - func functionTypeFunction(invocation Invocation) Value { interpreter := invocation.Interpreter diff --git a/runtime/sema/runtime_type_constructors.go b/runtime/sema/runtime_type_constructors.go index 7b46c6213c..893f5d6f5a 100644 --- a/runtime/sema/runtime_type_constructors.go +++ b/runtime/sema/runtime_type_constructors.go @@ -110,20 +110,6 @@ var CompositeTypeFunctionType = NewSimpleFunctionType( OptionalMetaTypeAnnotation, ) -const InterfaceTypeFunctionName = "InterfaceType" - -var InterfaceTypeFunctionType = NewSimpleFunctionType( - FunctionPurityView, - []Parameter{ - { - Label: ArgumentLabelNotRequired, - Identifier: "identifier", - TypeAnnotation: StringTypeAnnotation, - }, - }, - OptionalMetaTypeAnnotation, -) - const FunctionTypeFunctionName = "FunctionType" var FunctionTypeFunctionType = NewSimpleFunctionType( @@ -245,13 +231,6 @@ var runtimeTypeConstructors = []*RuntimeTypeConstructor{ Returns nil if the identifier does not correspond to any composite type.`, }, - { - Name: InterfaceTypeFunctionName, - Value: InterfaceTypeFunctionType, - DocString: `Creates a run-time type representing the interface type associated with the given type identifier. - Returns nil if the identifier does not correspond to any interface type.`, - }, - { Name: FunctionTypeFunctionName, Value: FunctionTypeFunctionType, diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 2a248bb082..87ffa41065 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -5906,16 +5906,9 @@ func (*InterfaceType) TypeAnnotationState() TypeAnnotationState { } func (t *InterfaceType) RewriteWithIntersectionTypes() (Type, bool) { - switch t.CompositeKind { - case common.CompositeKindResource, common.CompositeKindStructure: - return &IntersectionType{ - Types: []*InterfaceType{t}, - }, true - - default: - return t, false - } - + return &IntersectionType{ + Types: []*InterfaceType{t}, + }, true } func (*InterfaceType) Unify( diff --git a/runtime/tests/checker/access_test.go b/runtime/tests/checker/access_test.go index 5c464dc5c7..185787fb4d 100644 --- a/runtime/tests/checker/access_test.go +++ b/runtime/tests/checker/access_test.go @@ -28,7 +28,6 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/sema" - . "github.com/onflow/cadence/runtime/tests/utils" ) func expectSuccess(t *testing.T, err error) { @@ -772,7 +771,7 @@ func TestCheckAccessInterfaceFunction(t *testing.T) { if compositeKind == common.CompositeKindContract { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, @@ -990,7 +989,7 @@ func TestCheckAccessInterfaceFieldRead(t *testing.T) { if compositeKind == common.CompositeKindContract { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, @@ -1237,7 +1236,7 @@ func TestCheckAccessInterfaceFieldWrite(t *testing.T) { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, diff --git a/runtime/tests/checker/arrays_dictionaries_test.go b/runtime/tests/checker/arrays_dictionaries_test.go index 79d2d2b656..00eee40d08 100644 --- a/runtime/tests/checker/arrays_dictionaries_test.go +++ b/runtime/tests/checker/arrays_dictionaries_test.go @@ -30,7 +30,6 @@ import ( "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/sema" - . "github.com/onflow/cadence/runtime/tests/utils" ) func TestCheckDictionary(t *testing.T) { @@ -1464,24 +1463,18 @@ func TestCheckArraySubtyping(t *testing.T) { t.Run(kind.Keyword(), func(t *testing.T) { - body := "{}" - if kind == common.CompositeKindEvent { - body = "()" - } - - interfaceType := AsInterfaceType("I", kind) + interfaceType := "{I}" _, err := ParseAndCheck(t, fmt.Sprintf( ` - %[1]s interface I %[2]s - %[1]s S: I %[2]s + %[1]s interface I {} + %[1]s S: I {} - let xs: %[3]s[S] %[4]s [] - let ys: %[3]s[%[5]s] %[4]s xs + let xs: %[2]s[S] %[3]s [] + let ys: %[2]s[%[4]s] %[3]s xs `, kind.Keyword(), - body, kind.Annotation(), kind.TransferOperator(), interfaceType, @@ -1523,8 +1516,6 @@ func TestCheckDictionarySubtyping(t *testing.T) { body = "()" } - interfaceType := AsInterfaceType("I", kind) - _, err := ParseAndCheck(t, fmt.Sprintf( ` @@ -1532,13 +1523,12 @@ func TestCheckDictionarySubtyping(t *testing.T) { %[1]s S: I %[2]s let xs: %[3]s{String: S} %[4]s {} - let ys: %[3]s{String: %[5]s} %[4]s xs + let ys: %[3]s{String: {I}} %[4]s xs `, kind.Keyword(), body, kind.Annotation(), kind.TransferOperator(), - interfaceType, ), ) diff --git a/runtime/tests/checker/attachments_test.go b/runtime/tests/checker/attachments_test.go index 9aab821faa..0289e0803f 100644 --- a/runtime/tests/checker/attachments_test.go +++ b/runtime/tests/checker/attachments_test.go @@ -2583,7 +2583,7 @@ func TestCheckAttachmentAnyAttachmentTypes(t *testing.T) { }, { setupCode: "contract interface CI {}", - subType: "CI", + subType: "{CI}", expectedSuccess: false, }, } diff --git a/runtime/tests/checker/composite_test.go b/runtime/tests/checker/composite_test.go index 91671b4c50..a57ac1445c 100644 --- a/runtime/tests/checker/composite_test.go +++ b/runtime/tests/checker/composite_test.go @@ -29,7 +29,6 @@ import ( "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" - . "github.com/onflow/cadence/runtime/tests/utils" ) func TestCheckInvalidCompositeRedeclaringType(t *testing.T) { @@ -1963,14 +1962,14 @@ func TestCheckMutualTypeUseTopLevel(t *testing.T) { firstTypeAnnotation := "A" if firstIsInterface { firstInterfaceKeyword = "interface" - firstTypeAnnotation = AsInterfaceType("A", firstKind) + firstTypeAnnotation = "{A}" } secondInterfaceKeyword := "" secondTypeAnnotation := "B" if secondIsInterface { secondInterfaceKeyword = "interface" - secondTypeAnnotation = AsInterfaceType("B", secondKind) + secondTypeAnnotation = "{B}" } testName := fmt.Sprintf( diff --git a/runtime/tests/checker/interface_test.go b/runtime/tests/checker/interface_test.go index 0eae0fcd4e..a1c729f599 100644 --- a/runtime/tests/checker/interface_test.go +++ b/runtime/tests/checker/interface_test.go @@ -281,7 +281,7 @@ func TestCheckInterfaceUse(t *testing.T) { body = "()" } - annotationType := AsInterfaceType("Test", kind) + annotationType := "{Test}" t.Run(kind.Keyword(), func(t *testing.T) { @@ -320,7 +320,7 @@ func TestCheckInterfaceConformanceNoRequirements(t *testing.T) { body = "()" } - annotationType := AsInterfaceType("Test", compositeKind) + annotationType := "{Test}" var useCode string if compositeKind != common.CompositeKindContract { @@ -386,7 +386,7 @@ func TestCheckInvalidInterfaceConformanceIncompatibleCompositeKinds(t *testing.T secondBody = "()" } - firstKindInterfaceType := AsInterfaceType("Test", firstKind) + firstKindInterfaceType := "{Test}" // NOTE: type mismatch is only tested when both kinds are not contracts // (which can not be passed by value) @@ -455,7 +455,7 @@ func TestCheckInvalidInterfaceConformanceUndeclared(t *testing.T) { continue } - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" var useCode string if compositeKind != common.CompositeKindContract { @@ -557,7 +557,7 @@ func TestCheckInterfaceFieldUse(t *testing.T) { t.Run(compositeKind.Keyword(), func(t *testing.T) { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" _, err := ParseAndCheck(t, fmt.Sprintf( @@ -602,7 +602,7 @@ func TestCheckInvalidInterfaceUndeclaredFieldUse(t *testing.T) { continue } - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" t.Run(compositeKind.Keyword(), func(t *testing.T) { @@ -648,7 +648,7 @@ func TestCheckInterfaceFunctionUse(t *testing.T) { if compositeKind != common.CompositeKindContract { identifier = "test" - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, @@ -704,7 +704,7 @@ func TestCheckInvalidInterfaceUndeclaredFunctionUse(t *testing.T) { t.Run(compositeKind.Keyword(), func(t *testing.T) { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" _, err := ParseAndCheck(t, fmt.Sprintf( @@ -3697,7 +3697,7 @@ func TestCheckInheritedInterfacesSubtyping(t *testing.T) { contract S: B {} - fun foo(a: [S]): [A] { + fun foo(a: [S]): [{A}] { return a // must be covariant } `) @@ -3716,7 +3716,7 @@ func TestCheckInheritedInterfacesSubtyping(t *testing.T) { contract S: B {} - fun foo(a: [B]): [A] { + fun foo(a: [{B}]): [{A}] { return a // must be covariant } `) diff --git a/runtime/tests/checker/resources_test.go b/runtime/tests/checker/resources_test.go index 6b589d8f6c..9f79e0f1bc 100644 --- a/runtime/tests/checker/resources_test.go +++ b/runtime/tests/checker/resources_test.go @@ -3097,7 +3097,7 @@ func testResourceNesting( innerTypeAnnotation := "T" if innerIsInterface { - innerTypeAnnotation = AsInterfaceType("T", innerCompositeKind) + innerTypeAnnotation = "{T}" } // Prepare the initializer, if needed. diff --git a/runtime/tests/checker/runtimetype_test.go b/runtime/tests/checker/runtimetype_test.go index 796b771a0c..4b5b78e605 100644 --- a/runtime/tests/checker/runtimetype_test.go +++ b/runtime/tests/checker/runtimetype_test.go @@ -429,63 +429,6 @@ func TestCheckCompositeTypeConstructor(t *testing.T) { } } -func TestCheckInterfaceTypeConstructor(t *testing.T) { - - t.Parallel() - - cases := []struct { - name string - code string - expectedError error - }{ - { - name: "R", - code: ` - let result = InterfaceType("R") - `, - expectedError: nil, - }, - { - name: "type mismatch", - code: ` - let result = InterfaceType(3) - `, - expectedError: &sema.TypeMismatchError{}, - }, - { - name: "too many args", - code: ` - let result = InterfaceType("", 3) - `, - expectedError: &sema.ExcessiveArgumentsError{}, - }, - { - name: "no args", - code: ` - let result = InterfaceType() - `, - expectedError: &sema.InsufficientArgumentsError{}, - }, - } - - for _, testCase := range cases { - t.Run(testCase.name, func(t *testing.T) { - checker, err := ParseAndCheck(t, testCase.code) - - if testCase.expectedError == nil { - require.NoError(t, err) - assert.Equal(t, - &sema.OptionalType{Type: sema.MetaType}, - RequireGlobalValue(t, checker.Elaboration, "result"), - ) - } else { - errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, testCase.expectedError, errs[0]) - } - }) - } -} - func TestCheckFunctionTypeConstructor(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/condition_test.go b/runtime/tests/interpreter/condition_test.go index 6f69055e89..841b1f0c36 100644 --- a/runtime/tests/interpreter/condition_test.go +++ b/runtime/tests/interpreter/condition_test.go @@ -620,7 +620,7 @@ func TestInterpretInterfaceFunctionUseWithPreCondition(t *testing.T) { if compositeKind == common.CompositeKindContract { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, @@ -839,7 +839,7 @@ func TestInterpretInitializerWithInterfacePreCondition(t *testing.T) { } ` } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" testFunction = fmt.Sprintf( diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 1b2db69e1d..7dae235c04 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -3973,7 +3973,7 @@ func TestInterpretInterfaceConformanceNoRequirements(t *testing.T) { continue } - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" t.Run(compositeKind.Keyword(), func(t *testing.T) { @@ -4017,7 +4017,7 @@ func TestInterpretInterfaceFieldUse(t *testing.T) { if compositeKind == common.CompositeKindContract { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `access(all) let test: %[1]s%[2]s %[3]s %[4]s TestImpl%[5]s`, @@ -4097,7 +4097,7 @@ func TestInterpretInterfaceFunctionUse(t *testing.T) { if compositeKind == common.CompositeKindContract { identifier = "TestImpl" } else { - interfaceType := AsInterfaceType("Test", compositeKind) + interfaceType := "{Test}" setupCode = fmt.Sprintf( `access(all) let test: %[1]s %[2]s %[3]s %[4]s TestImpl%[5]s`, diff --git a/runtime/tests/interpreter/runtimetype_test.go b/runtime/tests/interpreter/runtimetype_test.go index 1950c6eef5..07c13038ef 100644 --- a/runtime/tests/interpreter/runtimetype_test.go +++ b/runtime/tests/interpreter/runtimetype_test.go @@ -354,46 +354,6 @@ func TestInterpretCompositeType(t *testing.T) { ) } -func TestInterpretInterfaceType(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - resource interface R {} - struct interface S {} - struct B {} - - let a = InterfaceType("S.test.R")! - let b = InterfaceType("S.test.S")! - let c = InterfaceType("S.test.A") - let d = InterfaceType("S.test.B") - `) - - assert.Equal(t, - interpreter.TypeValue{ - Type: interpreter.NewInterfaceStaticTypeComputeTypeID(nil, utils.TestLocation, "R"), - }, - inter.Globals.Get("a").GetValue(), - ) - - assert.Equal(t, - interpreter.TypeValue{ - Type: interpreter.NewInterfaceStaticTypeComputeTypeID(nil, utils.TestLocation, "S"), - }, - inter.Globals.Get("b").GetValue(), - ) - - assert.Equal(t, - interpreter.Nil, - inter.Globals.Get("c").GetValue(), - ) - - assert.Equal(t, - interpreter.Nil, - inter.Globals.Get("d").GetValue(), - ) -} - func TestInterpretFunctionType(t *testing.T) { t.Parallel() diff --git a/runtime/tests/utils/utils.go b/runtime/tests/utils/utils.go index eb581635a8..a8a26f26dc 100644 --- a/runtime/tests/utils/utils.go +++ b/runtime/tests/utils/utils.go @@ -79,16 +79,6 @@ func AssertEqualWithDiff(t *testing.T, expected, actual any) { s.String(), ) } - -} - -func AsInterfaceType(name string, kind common.CompositeKind) string { - switch kind { - case common.CompositeKindResource, common.CompositeKindStructure: - return fmt.Sprintf("{%s}", name) - default: - return name - } } func DeploymentTransaction(name string, contract []byte) []byte {