Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the requirement for all interface types to have to occur in intersection types #3090

Merged
52 changes: 50 additions & 2 deletions migrations/statictypes/account_type_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
interpreter.NewVariableSizedStaticType(nil, authAccountReferenceType),
),
},
"interface": {
"non_intersection_interface": {
storedType: interpreter.NewInterfaceStaticType(
nil,
nil,
Expand All @@ -325,6 +325,38 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
fooBarQualifiedIdentifier,
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
},
"intersection_interface": {
storedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
fooBarQualifiedIdentifier,
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
expectedType: nil,
},
"composite": {
Expand Down Expand Up @@ -437,7 +469,23 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
typeValue := value.(interpreter.TypeValue)
if actualIntersectionType, ok := typeValue.Type.(*interpreter.IntersectionStaticType); ok {
expectedIntersectionType := testCase.expectedType.(*interpreter.IntersectionStaticType)
assert.True(t, actualIntersectionType.LegacyType.Equal(expectedIntersectionType.LegacyType))

if actualIntersectionType.LegacyType != nil {
assert.True(t,
actualIntersectionType.LegacyType.
Equal(expectedIntersectionType.LegacyType),
)
} else if expectedIntersectionType.LegacyType != nil {
assert.True(t,
expectedIntersectionType.LegacyType.
Equal(actualIntersectionType.LegacyType),
)
} else {
assert.Equal(t,
expectedIntersectionType.LegacyType,
actualIntersectionType.LegacyType,
)
}
}
} else {
expectedValue = interpreter.NewTypeValue(nil, testCase.storedType)
Expand Down
36 changes: 35 additions & 1 deletion migrations/statictypes/intersection_type_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func TestIntersectionTypeMigration(t *testing.T) {
),
},
// interface
"interface": {
"non_intersection_interface": {
storedType: interpreter.NewInterfaceStaticType(
nil,
nil,
Expand All @@ -343,6 +343,39 @@ func TestIntersectionTypeMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
"Foo.Bar",
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
},
"intersection_interface": {
storedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
"Foo.Bar",
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
expectedType: nil,
},
// composite
"composite": {
Expand All @@ -356,6 +389,7 @@ func TestIntersectionTypeMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: nil,
},
}

Expand Down
64 changes: 44 additions & 20 deletions migrations/statictypes/statictype_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,21 @@ func (m *StaticTypeMigration) Migrate(
) (newValue interpreter.Value, err error) {
switch value := value.(type) {
case interpreter.TypeValue:
convertedType := m.maybeConvertStaticType(value.Type)
convertedType := m.maybeConvertStaticType(value.Type, nil)
if convertedType == nil {
return
}
return interpreter.NewTypeValue(nil, convertedType), nil

case *interpreter.CapabilityValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
return interpreter.NewUnmeteredCapabilityValue(value.ID, value.Address, convertedBorrowType), nil

case *interpreter.PathCapabilityValue: //nolint:staticcheck
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -91,7 +91,7 @@ func (m *StaticTypeMigration) Migrate(
}, nil

case interpreter.PathLinkValue: //nolint:staticcheck
convertedBorrowType := m.maybeConvertStaticType(value.Type)
convertedBorrowType := m.maybeConvertStaticType(value.Type, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -101,15 +101,15 @@ func (m *StaticTypeMigration) Migrate(
}, nil

case *interpreter.AccountCapabilityControllerValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
borrowType := convertedBorrowType.(*interpreter.ReferenceStaticType)
return interpreter.NewUnmeteredAccountCapabilityControllerValue(borrowType, value.CapabilityID), nil

case *interpreter.StorageCapabilityControllerValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -124,23 +124,23 @@ func (m *StaticTypeMigration) Migrate(
return
}

func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.StaticType) interpreter.StaticType {
func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType interpreter.StaticType) interpreter.StaticType {
switch staticType := staticType.(type) {
case *interpreter.ConstantSizedStaticType:
convertedType := m.maybeConvertStaticType(staticType.Type)
convertedType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedType != nil {
return interpreter.NewConstantSizedStaticType(nil, convertedType, staticType.Size)
}

case *interpreter.VariableSizedStaticType:
convertedType := m.maybeConvertStaticType(staticType.Type)
convertedType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedType != nil {
return interpreter.NewVariableSizedStaticType(nil, convertedType)
}

case *interpreter.DictionaryStaticType:
convertedKeyType := m.maybeConvertStaticType(staticType.KeyType)
convertedValueType := m.maybeConvertStaticType(staticType.ValueType)
convertedKeyType := m.maybeConvertStaticType(staticType.KeyType, staticType)
convertedValueType := m.maybeConvertStaticType(staticType.ValueType, staticType)
if convertedKeyType != nil && convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType)
}
Expand All @@ -152,7 +152,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
}

case *interpreter.CapabilityStaticType:
convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType, staticType)
if convertedBorrowType != nil {
return interpreter.NewCapabilityStaticType(nil, convertedBorrowType)
}
Expand All @@ -164,7 +164,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
var convertedInterfaceType bool

for _, interfaceStaticType := range staticType.Types {
convertedType := m.maybeConvertStaticType(interfaceStaticType)
convertedType := m.maybeConvertStaticType(interfaceStaticType, staticType)

// lazily allocate the slice
if convertedInterfaceTypes == nil {
Expand Down Expand Up @@ -194,7 +194,7 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
legacyType := staticType.LegacyType
var convertedLegacyType interpreter.StaticType
if legacyType != nil {
convertedLegacyType = m.maybeConvertStaticType(legacyType)
convertedLegacyType = m.maybeConvertStaticType(legacyType, staticType)
}

// If the interface set has at least two items,
Expand All @@ -214,14 +214,14 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
}

case *interpreter.OptionalStaticType:
convertedInnerType := m.maybeConvertStaticType(staticType.Type)
convertedInnerType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedInnerType != nil {
return interpreter.NewOptionalStaticType(nil, convertedInnerType)
}

case *interpreter.ReferenceStaticType:
// TODO: Reference of references must not be allowed?
convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType)
convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType, staticType)
if convertedReferencedType != nil {
switch convertedReferencedType {

Expand All @@ -232,7 +232,11 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
return convertedReferencedType

default:
return interpreter.NewReferenceStaticType(nil, staticType.Authorization, convertedReferencedType)
return interpreter.NewReferenceStaticType(
nil,
staticType.Authorization,
convertedReferencedType,
)
}
}

Expand All @@ -246,16 +250,36 @@ func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.Stat
}

case *interpreter.InterfaceStaticType:
var convertedType interpreter.StaticType
interfaceTypeConverter := m.interfaceTypeConverter
if interfaceTypeConverter != nil {
return interfaceTypeConverter(staticType)
convertedType = interfaceTypeConverter(staticType)
}

// Interface types need to be placed in intersection types
if _, ok := parentType.(*interpreter.IntersectionStaticType); !ok {
if convertedType == nil {
convertedType = interpreter.NewIntersectionStaticType(
nil, []*interpreter.InterfaceStaticType{
staticType,
},
)
} else if convertedInterfaceType, ok := convertedType.(*interpreter.InterfaceStaticType); ok {
SupunS marked this conversation as resolved.
Show resolved Hide resolved
convertedType = interpreter.NewIntersectionStaticType(
nil, []*interpreter.InterfaceStaticType{
convertedInterfaceType,
},
)
}
}

return convertedType

case dummyStaticType:
// This is for testing the migration.
// i.e: wrapper was only to make it possible to use as a dictionary-key.
// i.e: the dummyStaticType wrapper was only introduced to make it possible to use the type as a dictionary key.
// Ignore the wrapper, and continue with the inner type.
return m.maybeConvertStaticType(staticType.PrimitiveStaticType)
return m.maybeConvertStaticType(staticType.PrimitiveStaticType, staticType)

case interpreter.PrimitiveStaticType:
// Is it safe to do so?
Expand Down