diff --git a/migrations/cache.go b/migrations/cache.go new file mode 100644 index 0000000000..393df86f59 --- /dev/null +++ b/migrations/cache.go @@ -0,0 +1,41 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrations + +import ( + "sync" + + "github.com/onflow/cadence/runtime/interpreter" +) + +type StaticTypeCache struct { + entries sync.Map +} + +func (c *StaticTypeCache) Get(typeID interpreter.TypeID) (interpreter.StaticType, bool) { + v, ok := c.entries.Load(typeID) + if !ok { + return nil, false + } + return v.(interpreter.StaticType), true +} + +func (c *StaticTypeCache) Set(typeID interpreter.TypeID, ty interpreter.StaticType) { + c.entries.Store(typeID, ty) +} diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 3b73fecc23..6586b92666 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -28,13 +28,24 @@ import ( ) type EntitlementsMigration struct { - Interpreter *interpreter.Interpreter + Interpreter *interpreter.Interpreter + migratedTypeCache *migrations.StaticTypeCache } var _ migrations.ValueMigration = EntitlementsMigration{} func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration { - return EntitlementsMigration{Interpreter: inter} + return NewEntitlementsMigrationWithCache(inter, &migrations.StaticTypeCache{}) +} + +func NewEntitlementsMigrationWithCache( + inter *interpreter.Interpreter, + migratedTypeCache *migrations.StaticTypeCache, +) EntitlementsMigration { + return EntitlementsMigration{ + Interpreter: inter, + migratedTypeCache: migratedTypeCache, + } } func (EntitlementsMigration) Name() string { @@ -57,12 +68,11 @@ func (EntitlementsMigration) Domains() map[string]struct{} { // where `Entitlements(I)` is defined as the result of `T.SupportedEntitlements()` // // TODO: functions? -func ConvertToEntitledType( - inter *interpreter.Interpreter, +func (m EntitlementsMigration) ConvertToEntitledType( staticType interpreter.StaticType, ) ( - interpreter.StaticType, - error, + resultType interpreter.StaticType, + conversionErr error, ) { if staticType == nil { return nil, nil @@ -72,14 +82,30 @@ func ConvertToEntitledType( return nil, fmt.Errorf("cannot migrate deprecated type: %s", staticType) } + inter := m.Interpreter + migratedTypeCache := m.migratedTypeCache + + staticTypeID := staticType.ID() + + if migratedType, exists := migratedTypeCache.Get(staticTypeID); exists { + return migratedType, nil + } + + defer func() { + if resultType != nil && conversionErr == nil { + migratedTypeCache.Set(staticTypeID, resultType) + } + }() + switch t := staticType.(type) { case *interpreter.ReferenceStaticType: referencedType := t.ReferencedType - convertedReferencedType, err := ConvertToEntitledType(inter, referencedType) + convertedReferencedType, err := m.ConvertToEntitledType(referencedType) if err != nil { - return nil, err + conversionErr = err + return } var returnNew bool @@ -137,100 +163,109 @@ func ConvertToEntitledType( } if returnNew { - return interpreter.NewReferenceStaticType(nil, auth, referencedType), nil + resultType = interpreter.NewReferenceStaticType(nil, auth, referencedType) + return } case *interpreter.CapabilityStaticType: - convertedBorrowType, err := ConvertToEntitledType(inter, t.BorrowType) + convertedBorrowType, err := m.ConvertToEntitledType(t.BorrowType) if err != nil { - return nil, err + conversionErr = err + return } if convertedBorrowType != nil { - return interpreter.NewCapabilityStaticType(nil, convertedBorrowType), nil + resultType = interpreter.NewCapabilityStaticType(nil, convertedBorrowType) + return } case *interpreter.VariableSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(inter, elementType) + convertedElementType, err := m.ConvertToEntitledType(elementType) if err != nil { - return nil, err + conversionErr = err + return } if convertedElementType != nil { - return interpreter.NewVariableSizedStaticType(nil, convertedElementType), nil + resultType = interpreter.NewVariableSizedStaticType(nil, convertedElementType) + return } case *interpreter.ConstantSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(inter, elementType) + convertedElementType, err := m.ConvertToEntitledType(elementType) if err != nil { - return nil, err + conversionErr = err + return } if convertedElementType != nil { - return interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size), nil + resultType = interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size) + return } case *interpreter.DictionaryStaticType: keyType := t.KeyType - convertedKeyType, err := ConvertToEntitledType(inter, keyType) + convertedKeyType, err := m.ConvertToEntitledType(keyType) if err != nil { - return nil, err + conversionErr = err + return } valueType := t.ValueType - convertedValueType, err := ConvertToEntitledType(inter, valueType) + convertedValueType, err := m.ConvertToEntitledType(valueType) if err != nil { - return nil, err + conversionErr = err + return } if convertedKeyType != nil { if convertedValueType != nil { - return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType) + return } else { - return interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType) + return } } else if convertedValueType != nil { - return interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType) + return } case *interpreter.OptionalStaticType: innerType := t.Type - convertedInnerType, err := ConvertToEntitledType(inter, innerType) + convertedInnerType, err := m.ConvertToEntitledType(innerType) if err != nil { - return nil, err + conversionErr = err + return } if convertedInnerType != nil { - return interpreter.NewOptionalStaticType(nil, convertedInnerType), nil + resultType = interpreter.NewOptionalStaticType(nil, convertedInnerType) + return } } - return nil, nil + return } // ConvertValueToEntitlements converts the input value into a version compatible with the new entitlements feature, // with the same members/operations accessible on any references as would have been accessible in the past. -func ConvertValueToEntitlements( - inter *interpreter.Interpreter, - v interpreter.Value, -) ( - interpreter.Value, - error, -) { +func (m EntitlementsMigration) ConvertValueToEntitlements(v interpreter.Value) (interpreter.Value, error) { + inter := m.Interpreter switch v := v.(type) { case *interpreter.ArrayValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(inter, elementType) + entitledElementType, err := m.ConvertToEntitledType(elementType) if err != nil { return nil, err } @@ -246,7 +281,7 @@ func ConvertValueToEntitlements( case *interpreter.DictionaryValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(inter, elementType) + entitledElementType, err := m.ConvertToEntitledType(elementType) if err != nil { return nil, err } @@ -262,7 +297,7 @@ func ConvertValueToEntitlements( case *interpreter.IDCapabilityValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -279,7 +314,7 @@ func ConvertValueToEntitlements( case *interpreter.PathCapabilityValue: //nolint:staticcheck borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -295,7 +330,7 @@ func ConvertValueToEntitlements( case interpreter.TypeValue: ty := v.Type - entitledType, err := ConvertToEntitledType(inter, ty) + entitledType, err := m.ConvertToEntitledType(ty) if err != nil { return nil, err } @@ -307,7 +342,7 @@ func ConvertValueToEntitlements( case *interpreter.AccountCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -323,7 +358,7 @@ func ConvertValueToEntitlements( case *interpreter.StorageCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -340,7 +375,7 @@ func ConvertValueToEntitlements( case interpreter.PathLinkValue: //nolint:staticcheck borrowType := v.Type - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -356,7 +391,7 @@ func ConvertValueToEntitlements( return nil, nil } -func (mig EntitlementsMigration) Migrate( +func (m EntitlementsMigration) Migrate( _ interpreter.StorageKey, _ interpreter.StorageMapKey, value interpreter.Value, @@ -365,9 +400,9 @@ func (mig EntitlementsMigration) Migrate( interpreter.Value, error, ) { - return ConvertValueToEntitlements(mig.Interpreter, value) + return m.ConvertValueToEntitlements(value) } -func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { +func (m EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { return statictypes.CanSkipStaticTypeMigration(valueType) } diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 154f63e677..0a707d6d4e 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -49,6 +49,7 @@ func TestConvertToEntitledType(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) elaboration := sema.NewElaboration(nil) @@ -458,7 +459,7 @@ func TestConvertToEntitledType(t *testing.T) { Name: "int", }, { - Input: sema.NewReferenceType(nil, sema.UnauthorizedAccess, &sema.FunctionType{}), + Input: sema.NewReferenceType(nil, sema.UnauthorizedAccess, &sema.FunctionType{ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.IntType)}), Output: nil, Name: "function", }, @@ -646,7 +647,7 @@ func TestConvertToEntitledType(t *testing.T) { t.Run(test.Name, func(t *testing.T) { inputStaticType := interpreter.ConvertSemaToStaticType(nil, test.Input) - convertedType, _ := ConvertToEntitledType(inter, inputStaticType) + convertedType, _ := migration.ConvertToEntitledType(inputStaticType) expectedType := interpreter.ConvertSemaToStaticType(nil, test.Output) @@ -693,7 +694,8 @@ func (m testEntitlementsMigration) Migrate( interpreter.Value, error, ) { - return ConvertValueToEntitlements(m.inter, value) + migration := NewEntitlementsMigration(m.inter) + return migration.ConvertValueToEntitlements(value) } func (m testEntitlementsMigration) CanSkip(_ interpreter.StaticType) bool { @@ -1575,7 +1577,10 @@ func TestMigrateSimpleContract(t *testing.T) { func TestNilTypeValue(t *testing.T) { t.Parallel() - result, err := ConvertValueToEntitlements(nil, interpreter.NewTypeValue(nil, nil)) + migration := NewEntitlementsMigration(nil) + result, err := migration.ConvertValueToEntitlements( + interpreter.NewTypeValue(nil, nil), + ) require.NoError(t, err) require.Nil(t, result) } @@ -1583,8 +1588,8 @@ func TestNilTypeValue(t *testing.T) { func TestNilPathCapabilityValue(t *testing.T) { t.Parallel() - result, err := ConvertValueToEntitlements( - NewTestInterpreter(t), + migration := NewEntitlementsMigration(NewTestInterpreter(t)) + result, err := migration.ConvertValueToEntitlements( &interpreter.PathCapabilityValue{ //nolint:staticcheck Address: interpreter.NewAddressValue(nil, common.MustBytesToAddress([]byte{0x1})), Path: interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "test"), @@ -2803,6 +2808,7 @@ func TestConvertDeprecatedStaticTypes(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) value := interpreter.NewUnmeteredCapabilityValue( 1, interpreter.AddressValue(common.ZeroAddress), @@ -2813,7 +2819,7 @@ func TestConvertDeprecatedStaticTypes(t *testing.T) { ), ) - result, err := ConvertValueToEntitlements(inter, value) + result, err := migration.ConvertValueToEntitlements(value) require.Error(t, err) assert.ErrorContains(t, err, "cannot migrate deprecated type") require.Nil(t, result) @@ -2839,6 +2845,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) value := interpreter.NewUnmeteredCapabilityValue( 1, interpreter.AddressValue(common.ZeroAddress), @@ -2859,7 +2866,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { require.NoError(t, err) require.NotNil(t, newValue) - result, err := ConvertValueToEntitlements(inter, newValue) + result, err := migration.ConvertValueToEntitlements(newValue) require.NoError(t, err) require.Nilf(t, result, "expected no migration, but got %s", result) })