Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the requirement for all interface types to have to occur in intersection types #3090

Merged
52 changes: 50 additions & 2 deletions migrations/statictypes/account_type_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
interpreter.NewVariableSizedStaticType(nil, authAccountReferenceType),
),
},
"interface": {
"non_intersection_interface": {
storedType: interpreter.NewInterfaceStaticType(
nil,
nil,
Expand All @@ -325,6 +325,38 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
fooBarQualifiedIdentifier,
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
},
"intersection_interface": {
storedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
fooBarQualifiedIdentifier,
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
expectedType: nil,
},
"composite": {
Expand Down Expand Up @@ -437,7 +469,23 @@ func TestAccountTypeInTypeValueMigration(t *testing.T) {
typeValue := value.(interpreter.TypeValue)
if actualIntersectionType, ok := typeValue.Type.(*interpreter.IntersectionStaticType); ok {
expectedIntersectionType := testCase.expectedType.(*interpreter.IntersectionStaticType)
assert.True(t, actualIntersectionType.LegacyType.Equal(expectedIntersectionType.LegacyType))

if actualIntersectionType.LegacyType != nil {
assert.True(t,
actualIntersectionType.LegacyType.
Equal(expectedIntersectionType.LegacyType),
)
} else if expectedIntersectionType.LegacyType != nil {
assert.True(t,
expectedIntersectionType.LegacyType.
Equal(actualIntersectionType.LegacyType),
)
} else {
assert.Equal(t,
expectedIntersectionType.LegacyType,
actualIntersectionType.LegacyType,
)
}
}
} else {
expectedValue = interpreter.NewTypeValue(nil, testCase.storedType)
Expand Down
36 changes: 35 additions & 1 deletion migrations/statictypes/intersection_type_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func TestIntersectionTypeMigration(t *testing.T) {
),
},
// interface
"interface": {
"non_intersection_interface": {
storedType: interpreter.NewInterfaceStaticType(
nil,
nil,
Expand All @@ -343,6 +343,39 @@ func TestIntersectionTypeMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
"Foo.Bar",
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
},
"intersection_interface": {
storedType: interpreter.NewIntersectionStaticType(
nil,
[]*interpreter.InterfaceStaticType{
interpreter.NewInterfaceStaticType(
nil,
nil,
"Foo.Bar",
common.NewTypeIDFromQualifiedName(
nil,
fooAddressLocation,
fooBarQualifiedIdentifier,
),
),
},
),
expectedType: nil,
},
// composite
"composite": {
Expand All @@ -356,6 +389,7 @@ func TestIntersectionTypeMigration(t *testing.T) {
fooBarQualifiedIdentifier,
),
),
expectedType: nil,
},
}

Expand Down
64 changes: 44 additions & 20 deletions migrations/statictypes/statictype_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,21 @@
) (newValue interpreter.Value, err error) {
switch value := value.(type) {
case interpreter.TypeValue:
convertedType := m.maybeConvertStaticType(value.Type)
convertedType := m.maybeConvertStaticType(value.Type, nil)
if convertedType == nil {
return
}
return interpreter.NewTypeValue(nil, convertedType), nil

case *interpreter.CapabilityValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
return interpreter.NewUnmeteredCapabilityValue(value.ID, value.Address, convertedBorrowType), nil

case *interpreter.PathCapabilityValue: //nolint:staticcheck
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -91,7 +91,7 @@
}, nil

case interpreter.PathLinkValue: //nolint:staticcheck
convertedBorrowType := m.maybeConvertStaticType(value.Type)
convertedBorrowType := m.maybeConvertStaticType(value.Type, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -101,15 +101,15 @@
}, nil

case *interpreter.AccountCapabilityControllerValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
borrowType := convertedBorrowType.(*interpreter.ReferenceStaticType)
return interpreter.NewUnmeteredAccountCapabilityControllerValue(borrowType, value.CapabilityID), nil

case *interpreter.StorageCapabilityControllerValue:
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(value.BorrowType, nil)
if convertedBorrowType == nil {
return
}
Expand All @@ -124,23 +124,23 @@
return
}

func (m *StaticTypeMigration) maybeConvertStaticType(staticType interpreter.StaticType) interpreter.StaticType {
func (m *StaticTypeMigration) maybeConvertStaticType(staticType, parentType interpreter.StaticType) interpreter.StaticType {
switch staticType := staticType.(type) {
case *interpreter.ConstantSizedStaticType:
convertedType := m.maybeConvertStaticType(staticType.Type)
convertedType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedType != nil {
return interpreter.NewConstantSizedStaticType(nil, convertedType, staticType.Size)
}

case *interpreter.VariableSizedStaticType:
convertedType := m.maybeConvertStaticType(staticType.Type)
convertedType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedType != nil {
return interpreter.NewVariableSizedStaticType(nil, convertedType)
}

case *interpreter.DictionaryStaticType:
convertedKeyType := m.maybeConvertStaticType(staticType.KeyType)
convertedValueType := m.maybeConvertStaticType(staticType.ValueType)
convertedKeyType := m.maybeConvertStaticType(staticType.KeyType, staticType)
convertedValueType := m.maybeConvertStaticType(staticType.ValueType, staticType)
if convertedKeyType != nil && convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType)
}
Expand All @@ -152,7 +152,7 @@
}

case *interpreter.CapabilityStaticType:
convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType)
convertedBorrowType := m.maybeConvertStaticType(staticType.BorrowType, staticType)
if convertedBorrowType != nil {
return interpreter.NewCapabilityStaticType(nil, convertedBorrowType)
}
Expand All @@ -164,7 +164,7 @@
var convertedInterfaceType bool

for _, interfaceStaticType := range staticType.Types {
convertedType := m.maybeConvertStaticType(interfaceStaticType)
convertedType := m.maybeConvertStaticType(interfaceStaticType, staticType)

// lazily allocate the slice
if convertedInterfaceTypes == nil {
Expand Down Expand Up @@ -194,7 +194,7 @@
legacyType := staticType.LegacyType
var convertedLegacyType interpreter.StaticType
if legacyType != nil {
convertedLegacyType = m.maybeConvertStaticType(legacyType)
convertedLegacyType = m.maybeConvertStaticType(legacyType, staticType)
}

// If the interface set has at least two items,
Expand All @@ -214,14 +214,14 @@
}

case *interpreter.OptionalStaticType:
convertedInnerType := m.maybeConvertStaticType(staticType.Type)
convertedInnerType := m.maybeConvertStaticType(staticType.Type, staticType)
if convertedInnerType != nil {
return interpreter.NewOptionalStaticType(nil, convertedInnerType)
}

case *interpreter.ReferenceStaticType:
// TODO: Reference of references must not be allowed?
convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType)
convertedReferencedType := m.maybeConvertStaticType(staticType.ReferencedType, staticType)
if convertedReferencedType != nil {
switch convertedReferencedType {

Expand All @@ -232,7 +232,11 @@
return convertedReferencedType

default:
return interpreter.NewReferenceStaticType(nil, staticType.Authorization, convertedReferencedType)
return interpreter.NewReferenceStaticType(
nil,
staticType.Authorization,
convertedReferencedType,
)
}
}

Expand All @@ -246,16 +250,36 @@
}

case *interpreter.InterfaceStaticType:
var convertedType interpreter.StaticType
interfaceTypeConverter := m.interfaceTypeConverter
if interfaceTypeConverter != nil {
return interfaceTypeConverter(staticType)
convertedType = interfaceTypeConverter(staticType)
}

// Interface types need to be placed in intersection types
if _, ok := parentType.(*interpreter.IntersectionStaticType); !ok {
if convertedType == nil {
convertedType = interpreter.NewIntersectionStaticType(
nil, []*interpreter.InterfaceStaticType{
staticType,
},
)
} else if convertedInterfaceType, ok := convertedType.(*interpreter.InterfaceStaticType); ok {
SupunS marked this conversation as resolved.
Show resolved Hide resolved
convertedType = interpreter.NewIntersectionStaticType(
nil, []*interpreter.InterfaceStaticType{
convertedInterfaceType,
},
)
}

Check warning on line 273 in migrations/statictypes/statictype_migration.go

View check run for this annotation

Codecov / codecov/patch

migrations/statictypes/statictype_migration.go#L268-L273

Added lines #L268 - L273 were not covered by tests
}

return convertedType

case dummyStaticType:
// This is for testing the migration.
// i.e: wrapper was only to make it possible to use as a dictionary-key.
// i.e: the dummyStaticType wrapper was only introduced to make it possible to use the type as a dictionary key.
// Ignore the wrapper, and continue with the inner type.
return m.maybeConvertStaticType(staticType.PrimitiveStaticType)
return m.maybeConvertStaticType(staticType.PrimitiveStaticType, staticType)

case interpreter.PrimitiveStaticType:
// Is it safe to do so?
Expand Down
4 changes: 2 additions & 2 deletions runtime/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ func TestRuntimeAuthAccountContracts(t *testing.T) {

transaction {
prepare(acc: &Account) {
let hello = acc.contracts.borrow<&HelloInterface>(name: "Hello")
let hello = acc.contracts.borrow<&{HelloInterface}>(name: "Hello")
assert(hello?.hello() == "Hello!")
}
}
Expand Down Expand Up @@ -1747,7 +1747,7 @@ func TestRuntimeAuthAccountContracts(t *testing.T) {

transaction {
prepare(acc: &Account) {
let hello = acc.contracts.borrow<&HelloInterface>(name: "Hello")
let hello = acc.contracts.borrow<&{HelloInterface}>(name: "Hello")
assert(hello == nil)
}
}
Expand Down
30 changes: 0 additions & 30 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3391,15 +3391,6 @@ func init() {
),
)

defineBaseValue(
BaseActivation,
sema.InterfaceTypeFunctionName,
NewUnmeteredHostFunctionValue(
sema.InterfaceTypeFunctionType,
interfaceTypeFunction,
),
)

defineBaseValue(
BaseActivation,
sema.FunctionTypeFunctionName,
Expand Down Expand Up @@ -3536,27 +3527,6 @@ func compositeTypeFunction(invocation Invocation) Value {
)
}

func interfaceTypeFunction(invocation Invocation) Value {
typeIDValue, ok := invocation.Arguments[0].(*StringValue)
if !ok {
panic(errors.NewUnreachableError())
}
typeID := typeIDValue.Str

interfaceType, err := lookupInterface(invocation.Interpreter, typeID)
if err != nil {
return Nil
}

return NewSomeValueNonCopying(
invocation.Interpreter,
NewTypeValue(
invocation.Interpreter,
ConvertSemaToStaticType(invocation.Interpreter, interfaceType),
),
)
}

func functionTypeFunction(invocation Invocation) Value {
interpreter := invocation.Interpreter

Expand Down
Loading
Loading