From e2f9ed3415d7a6237f01f2fed25a2a84bf58cefa Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 9 Apr 2024 12:39:39 -0400 Subject: [PATCH 1/5] cache results of entitlement type conversion --- migrations/entitlements/migration.go | 113 ++++++++++++++-------- migrations/entitlements/migration_test.go | 18 ++-- runtime/sema/type.go | 2 +- types_test.go | 2 +- 4 files changed, 88 insertions(+), 47 deletions(-) diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index fc492a1e57..92696a08f8 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -23,18 +23,23 @@ import ( "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/migrations/statictypes" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" ) type EntitlementsMigration struct { - Interpreter *interpreter.Interpreter + Interpreter *interpreter.Interpreter + migratedTypeCache map[common.TypeID]interpreter.StaticType } var _ migrations.ValueMigration = EntitlementsMigration{} func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration { - return EntitlementsMigration{Interpreter: inter} + return EntitlementsMigration{ + Interpreter: inter, + migratedTypeCache: map[common.TypeID]interpreter.StaticType{}, + } } func (EntitlementsMigration) Name() string { @@ -58,11 +63,11 @@ func (EntitlementsMigration) Domains() map[string]struct{} { // // TODO: functions? func ConvertToEntitledType( - inter *interpreter.Interpreter, + mig EntitlementsMigration, staticType interpreter.StaticType, ) ( - interpreter.StaticType, - error, + resultType interpreter.StaticType, + conversionErr error, ) { if staticType == nil { return nil, nil @@ -72,14 +77,30 @@ func ConvertToEntitledType( return nil, fmt.Errorf("cannot migrate deprecated type: %s", staticType) } + inter := mig.Interpreter + migratedTypeCache := mig.migratedTypeCache + + staticTypeID := staticType.ID() + + if migratedType, exists := migratedTypeCache[staticTypeID]; exists { + return migratedType, nil + } + + defer func() { + if resultType != nil && conversionErr == nil { + migratedTypeCache[staticTypeID] = resultType + } + }() + switch t := staticType.(type) { case *interpreter.ReferenceStaticType: referencedType := t.ReferencedType - convertedReferencedType, err := ConvertToEntitledType(inter, referencedType) + convertedReferencedType, err := ConvertToEntitledType(mig, referencedType) if err != nil { - return nil, err + conversionErr = err + return } var returnNew bool @@ -143,100 +164,116 @@ func ConvertToEntitledType( } if returnNew { - return interpreter.NewReferenceStaticType(nil, auth, referencedType), nil + resultType = interpreter.NewReferenceStaticType(nil, auth, referencedType) + return } case *interpreter.CapabilityStaticType: - convertedBorrowType, err := ConvertToEntitledType(inter, t.BorrowType) + convertedBorrowType, err := ConvertToEntitledType(mig, t.BorrowType) if err != nil { - return nil, err + conversionErr = err + return } if convertedBorrowType != nil { - return interpreter.NewCapabilityStaticType(nil, convertedBorrowType), nil + resultType = interpreter.NewCapabilityStaticType(nil, convertedBorrowType) + return } case *interpreter.VariableSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(inter, elementType) + convertedElementType, err := ConvertToEntitledType(mig, elementType) if err != nil { - return nil, err + conversionErr = err + return } if convertedElementType != nil { - return interpreter.NewVariableSizedStaticType(nil, convertedElementType), nil + resultType = interpreter.NewVariableSizedStaticType(nil, convertedElementType) + return } case *interpreter.ConstantSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(inter, elementType) + convertedElementType, err := ConvertToEntitledType(mig, elementType) if err != nil { - return nil, err + conversionErr = err + return } if convertedElementType != nil { - return interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size), nil + resultType = interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size) + return } case *interpreter.DictionaryStaticType: keyType := t.KeyType - convertedKeyType, err := ConvertToEntitledType(inter, keyType) + convertedKeyType, err := ConvertToEntitledType(mig, keyType) if err != nil { - return nil, err + conversionErr = err + return } valueType := t.ValueType - convertedValueType, err := ConvertToEntitledType(inter, valueType) + convertedValueType, err := ConvertToEntitledType(mig, valueType) if err != nil { - return nil, err + conversionErr = err + return } if convertedKeyType != nil { if convertedValueType != nil { - return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType) + return } else { - return interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType) + return } } else if convertedValueType != nil { - return interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType), nil + resultType = interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType) + return } case *interpreter.OptionalStaticType: innerType := t.Type - convertedInnerType, err := ConvertToEntitledType(inter, innerType) + convertedInnerType, err := ConvertToEntitledType(mig, innerType) if err != nil { - return nil, err + conversionErr = err + return } if convertedInnerType != nil { - return interpreter.NewOptionalStaticType(nil, convertedInnerType), nil + resultType = interpreter.NewOptionalStaticType(nil, convertedInnerType) + return } } - return nil, nil + return } // ConvertValueToEntitlements converts the input value into a version compatible with the new entitlements feature, // with the same members/operations accessible on any references as would have been accessible in the past. func ConvertValueToEntitlements( - inter *interpreter.Interpreter, + mig EntitlementsMigration, v interpreter.Value, ) ( interpreter.Value, error, ) { + inter := mig.Interpreter + switch v := v.(type) { case *interpreter.ArrayValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(inter, elementType) + entitledElementType, err := ConvertToEntitledType(mig, elementType) if err != nil { return nil, err } @@ -252,7 +289,7 @@ func ConvertValueToEntitlements( case *interpreter.DictionaryValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(inter, elementType) + entitledElementType, err := ConvertToEntitledType(mig, elementType) if err != nil { return nil, err } @@ -268,7 +305,7 @@ func ConvertValueToEntitlements( case *interpreter.IDCapabilityValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) if err != nil { return nil, err } @@ -285,7 +322,7 @@ func ConvertValueToEntitlements( case *interpreter.PathCapabilityValue: //nolint:staticcheck borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) if err != nil { return nil, err } @@ -301,7 +338,7 @@ func ConvertValueToEntitlements( case interpreter.TypeValue: ty := v.Type - entitledType, err := ConvertToEntitledType(inter, ty) + entitledType, err := ConvertToEntitledType(mig, ty) if err != nil { return nil, err } @@ -313,7 +350,7 @@ func ConvertValueToEntitlements( case *interpreter.AccountCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) if err != nil { return nil, err } @@ -329,7 +366,7 @@ func ConvertValueToEntitlements( case *interpreter.StorageCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) if err != nil { return nil, err } @@ -346,7 +383,7 @@ func ConvertValueToEntitlements( case interpreter.PathLinkValue: //nolint:staticcheck borrowType := v.Type - entitledBorrowType, err := ConvertToEntitledType(inter, borrowType) + entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) if err != nil { return nil, err } @@ -371,7 +408,7 @@ func (mig EntitlementsMigration) Migrate( interpreter.Value, error, ) { - return ConvertValueToEntitlements(mig.Interpreter, value) + return ConvertValueToEntitlements(mig, value) } func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 9650a14f7b..872af2fce3 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -48,6 +48,7 @@ func TestConvertToEntitledType(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) elaboration := sema.NewElaboration(nil) @@ -457,7 +458,7 @@ func TestConvertToEntitledType(t *testing.T) { Name: "int", }, { - Input: sema.NewReferenceType(nil, sema.UnauthorizedAccess, &sema.FunctionType{}), + Input: sema.NewReferenceType(nil, sema.UnauthorizedAccess, &sema.FunctionType{ReturnTypeAnnotation: sema.NewTypeAnnotation(sema.IntType)}), Output: nil, Name: "function", }, @@ -645,7 +646,7 @@ func TestConvertToEntitledType(t *testing.T) { t.Run(test.Name, func(t *testing.T) { inputStaticType := interpreter.ConvertSemaToStaticType(nil, test.Input) - convertedType, _ := ConvertToEntitledType(inter, inputStaticType) + convertedType, _ := ConvertToEntitledType(migration, inputStaticType) expectedType := interpreter.ConvertSemaToStaticType(nil, test.Output) @@ -692,7 +693,8 @@ func (m testEntitlementsMigration) Migrate( interpreter.Value, error, ) { - return ConvertValueToEntitlements(m.inter, value) + mig := NewEntitlementsMigration(m.inter) + return ConvertValueToEntitlements(mig, value) } func (m testEntitlementsMigration) CanSkip(_ interpreter.StaticType) bool { @@ -1572,7 +1574,7 @@ func TestMigrateSimpleContract(t *testing.T) { func TestNilTypeValue(t *testing.T) { t.Parallel() - result, err := ConvertValueToEntitlements(nil, interpreter.NewTypeValue(nil, nil)) + result, err := ConvertValueToEntitlements(NewEntitlementsMigration(nil), interpreter.NewTypeValue(nil, nil)) require.NoError(t, err) require.Nil(t, result) } @@ -1581,7 +1583,7 @@ func TestNilPathCapabilityValue(t *testing.T) { t.Parallel() result, err := ConvertValueToEntitlements( - NewTestInterpreter(t), + NewEntitlementsMigration(NewTestInterpreter(t)), &interpreter.PathCapabilityValue{ //nolint:staticcheck Address: interpreter.NewAddressValue(nil, common.MustBytesToAddress([]byte{0x1})), Path: interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "test"), @@ -2787,6 +2789,7 @@ func TestConvertDeprecatedStaticTypes(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) value := interpreter.NewUnmeteredCapabilityValue( 1, interpreter.AddressValue(common.ZeroAddress), @@ -2797,7 +2800,7 @@ func TestConvertDeprecatedStaticTypes(t *testing.T) { ), ) - result, err := ConvertValueToEntitlements(inter, value) + result, err := ConvertValueToEntitlements(migration, value) require.Error(t, err) assert.ErrorContains(t, err, "cannot migrate deprecated type") require.Nil(t, result) @@ -2823,6 +2826,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { t.Parallel() inter := NewTestInterpreter(t) + migration := NewEntitlementsMigration(inter) value := interpreter.NewUnmeteredCapabilityValue( 1, interpreter.AddressValue(common.ZeroAddress), @@ -2843,7 +2847,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { require.NoError(t, err) require.NotNil(t, newValue) - result, err := ConvertValueToEntitlements(inter, newValue) + result, err := ConvertValueToEntitlements(migration, newValue) require.NoError(t, err) require.Nilf(t, result, "expected no migration, but got %s", result) }) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 1cbdad1d5f..6320cb6df8 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -693,7 +693,7 @@ func (t *OptionalType) QualifiedString() string { } func FormatOptionalTypeID[T ~string](elementTypeID T) T { - return T(fmt.Sprintf("%s?", elementTypeID)) + return T(fmt.Sprintf("(%s)?", elementTypeID)) } func (t *OptionalType) ID() TypeID { diff --git a/types_test.go b/types_test.go index 395134f55f..d617f13b2a 100644 --- a/types_test.go +++ b/types_test.go @@ -55,7 +55,7 @@ func TestType_ID(t *testing.T) { &OptionalType{ Type: StringType, }, - "String?", + "(String)?", }, { &VariableSizedArrayType{ From 3a24334f143eb3fac7dc110ecebf2112226408a6 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 9 Apr 2024 14:03:54 -0400 Subject: [PATCH 2/5] add function to create migration with cache --- migrations/entitlements/migration.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 92696a08f8..3ad22e777b 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -42,6 +42,16 @@ func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigrat } } +func NewEntitlementsMigrationWithCache( + inter *interpreter.Interpreter, + migratedTypeCache map[common.TypeID]interpreter.StaticType, +) EntitlementsMigration { + return EntitlementsMigration{ + Interpreter: inter, + migratedTypeCache: migratedTypeCache, + } +} + func (EntitlementsMigration) Name() string { return "EntitlementsMigration" } From c9e18b7040bbda20c2376cc43f37d607623b7fd4 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 9 Apr 2024 14:38:50 -0400 Subject: [PATCH 3/5] swap to sync.map --- migrations/entitlements/migration.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 3ad22e777b..3e86562ce3 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -20,17 +20,21 @@ package entitlements import ( "fmt" + "sync" "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/migrations/statictypes" - "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" ) +type StaticTypeCache struct { + m sync.Map +} + type EntitlementsMigration struct { Interpreter *interpreter.Interpreter - migratedTypeCache map[common.TypeID]interpreter.StaticType + migratedTypeCache *StaticTypeCache } var _ migrations.ValueMigration = EntitlementsMigration{} @@ -38,13 +42,13 @@ var _ migrations.ValueMigration = EntitlementsMigration{} func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration { return EntitlementsMigration{ Interpreter: inter, - migratedTypeCache: map[common.TypeID]interpreter.StaticType{}, + migratedTypeCache: &StaticTypeCache{m: sync.Map{}}, } } func NewEntitlementsMigrationWithCache( inter *interpreter.Interpreter, - migratedTypeCache map[common.TypeID]interpreter.StaticType, + migratedTypeCache *StaticTypeCache, ) EntitlementsMigration { return EntitlementsMigration{ Interpreter: inter, @@ -92,13 +96,13 @@ func ConvertToEntitledType( staticTypeID := staticType.ID() - if migratedType, exists := migratedTypeCache[staticTypeID]; exists { - return migratedType, nil + if migratedType, exists := migratedTypeCache.m.Load(staticTypeID); exists { + return migratedType.(interpreter.StaticType), nil } defer func() { if resultType != nil && conversionErr == nil { - migratedTypeCache[staticTypeID] = resultType + migratedTypeCache.m.Store(staticTypeID, resultType) } }() From f934ec4dec3bb1055c432e2991f3ccedce0556db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 24 Apr 2024 10:24:46 -0700 Subject: [PATCH 4/5] refactor and improve static type cache --- migrations/cache.go | 41 ++++++++++++++++++++++++++++ migrations/entitlements/migration.go | 20 ++++---------- 2 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 migrations/cache.go diff --git a/migrations/cache.go b/migrations/cache.go new file mode 100644 index 0000000000..393df86f59 --- /dev/null +++ b/migrations/cache.go @@ -0,0 +1,41 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrations + +import ( + "sync" + + "github.com/onflow/cadence/runtime/interpreter" +) + +type StaticTypeCache struct { + entries sync.Map +} + +func (c *StaticTypeCache) Get(typeID interpreter.TypeID) (interpreter.StaticType, bool) { + v, ok := c.entries.Load(typeID) + if !ok { + return nil, false + } + return v.(interpreter.StaticType), true +} + +func (c *StaticTypeCache) Set(typeID interpreter.TypeID, ty interpreter.StaticType) { + c.entries.Store(typeID, ty) +} diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 2a95425256..1657bce01d 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -20,7 +20,6 @@ package entitlements import ( "fmt" - "sync" "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/migrations/statictypes" @@ -28,27 +27,20 @@ import ( "github.com/onflow/cadence/runtime/sema" ) -type StaticTypeCache struct { - m sync.Map -} - type EntitlementsMigration struct { Interpreter *interpreter.Interpreter - migratedTypeCache *StaticTypeCache + migratedTypeCache *migrations.StaticTypeCache } var _ migrations.ValueMigration = EntitlementsMigration{} func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration { - return EntitlementsMigration{ - Interpreter: inter, - migratedTypeCache: &StaticTypeCache{m: sync.Map{}}, - } + return NewEntitlementsMigrationWithCache(inter, &migrations.StaticTypeCache{}) } func NewEntitlementsMigrationWithCache( inter *interpreter.Interpreter, - migratedTypeCache *StaticTypeCache, + migratedTypeCache *migrations.StaticTypeCache, ) EntitlementsMigration { return EntitlementsMigration{ Interpreter: inter, @@ -96,13 +88,13 @@ func ConvertToEntitledType( staticTypeID := staticType.ID() - if migratedType, exists := migratedTypeCache.m.Load(staticTypeID); exists { - return migratedType.(interpreter.StaticType), nil + if migratedType, exists := migratedTypeCache.Get(staticTypeID); exists { + return migratedType, nil } defer func() { if resultType != nil && conversionErr == nil { - migratedTypeCache.m.Store(staticTypeID, resultType) + migratedTypeCache.Set(staticTypeID, resultType) } }() From 6f4698ea3c9b2da393e9afae74e7567a86eda113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 24 Apr 2024 10:34:11 -0700 Subject: [PATCH 5/5] make global functions methods of migration --- migrations/entitlements/migration.go | 54 ++++++++++------------- migrations/entitlements/migration_test.go | 19 ++++---- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 1657bce01d..6586b92666 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -68,8 +68,7 @@ func (EntitlementsMigration) Domains() map[string]struct{} { // where `Entitlements(I)` is defined as the result of `T.SupportedEntitlements()` // // TODO: functions? -func ConvertToEntitledType( - mig EntitlementsMigration, +func (m EntitlementsMigration) ConvertToEntitledType( staticType interpreter.StaticType, ) ( resultType interpreter.StaticType, @@ -83,8 +82,8 @@ func ConvertToEntitledType( return nil, fmt.Errorf("cannot migrate deprecated type: %s", staticType) } - inter := mig.Interpreter - migratedTypeCache := mig.migratedTypeCache + inter := m.Interpreter + migratedTypeCache := m.migratedTypeCache staticTypeID := staticType.ID() @@ -103,7 +102,7 @@ func ConvertToEntitledType( referencedType := t.ReferencedType - convertedReferencedType, err := ConvertToEntitledType(mig, referencedType) + convertedReferencedType, err := m.ConvertToEntitledType(referencedType) if err != nil { conversionErr = err return @@ -169,7 +168,7 @@ func ConvertToEntitledType( } case *interpreter.CapabilityStaticType: - convertedBorrowType, err := ConvertToEntitledType(mig, t.BorrowType) + convertedBorrowType, err := m.ConvertToEntitledType(t.BorrowType) if err != nil { conversionErr = err return @@ -183,7 +182,7 @@ func ConvertToEntitledType( case *interpreter.VariableSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(mig, elementType) + convertedElementType, err := m.ConvertToEntitledType(elementType) if err != nil { conversionErr = err return @@ -197,7 +196,7 @@ func ConvertToEntitledType( case *interpreter.ConstantSizedStaticType: elementType := t.Type - convertedElementType, err := ConvertToEntitledType(mig, elementType) + convertedElementType, err := m.ConvertToEntitledType(elementType) if err != nil { conversionErr = err return @@ -211,7 +210,7 @@ func ConvertToEntitledType( case *interpreter.DictionaryStaticType: keyType := t.KeyType - convertedKeyType, err := ConvertToEntitledType(mig, keyType) + convertedKeyType, err := m.ConvertToEntitledType(keyType) if err != nil { conversionErr = err return @@ -219,7 +218,7 @@ func ConvertToEntitledType( valueType := t.ValueType - convertedValueType, err := ConvertToEntitledType(mig, valueType) + convertedValueType, err := m.ConvertToEntitledType(valueType) if err != nil { conversionErr = err return @@ -241,7 +240,7 @@ func ConvertToEntitledType( case *interpreter.OptionalStaticType: innerType := t.Type - convertedInnerType, err := ConvertToEntitledType(mig, innerType) + convertedInnerType, err := m.ConvertToEntitledType(innerType) if err != nil { conversionErr = err return @@ -258,22 +257,15 @@ func ConvertToEntitledType( // ConvertValueToEntitlements converts the input value into a version compatible with the new entitlements feature, // with the same members/operations accessible on any references as would have been accessible in the past. -func ConvertValueToEntitlements( - mig EntitlementsMigration, - v interpreter.Value, -) ( - interpreter.Value, - error, -) { - - inter := mig.Interpreter +func (m EntitlementsMigration) ConvertValueToEntitlements(v interpreter.Value) (interpreter.Value, error) { + inter := m.Interpreter switch v := v.(type) { case *interpreter.ArrayValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(mig, elementType) + entitledElementType, err := m.ConvertToEntitledType(elementType) if err != nil { return nil, err } @@ -289,7 +281,7 @@ func ConvertValueToEntitlements( case *interpreter.DictionaryValue: elementType := v.Type - entitledElementType, err := ConvertToEntitledType(mig, elementType) + entitledElementType, err := m.ConvertToEntitledType(elementType) if err != nil { return nil, err } @@ -305,7 +297,7 @@ func ConvertValueToEntitlements( case *interpreter.IDCapabilityValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -322,7 +314,7 @@ func ConvertValueToEntitlements( case *interpreter.PathCapabilityValue: //nolint:staticcheck borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -338,7 +330,7 @@ func ConvertValueToEntitlements( case interpreter.TypeValue: ty := v.Type - entitledType, err := ConvertToEntitledType(mig, ty) + entitledType, err := m.ConvertToEntitledType(ty) if err != nil { return nil, err } @@ -350,7 +342,7 @@ func ConvertValueToEntitlements( case *interpreter.AccountCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -366,7 +358,7 @@ func ConvertValueToEntitlements( case *interpreter.StorageCapabilityControllerValue: borrowType := v.BorrowType - entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -383,7 +375,7 @@ func ConvertValueToEntitlements( case interpreter.PathLinkValue: //nolint:staticcheck borrowType := v.Type - entitledBorrowType, err := ConvertToEntitledType(mig, borrowType) + entitledBorrowType, err := m.ConvertToEntitledType(borrowType) if err != nil { return nil, err } @@ -399,7 +391,7 @@ func ConvertValueToEntitlements( return nil, nil } -func (mig EntitlementsMigration) Migrate( +func (m EntitlementsMigration) Migrate( _ interpreter.StorageKey, _ interpreter.StorageMapKey, value interpreter.Value, @@ -408,9 +400,9 @@ func (mig EntitlementsMigration) Migrate( interpreter.Value, error, ) { - return ConvertValueToEntitlements(mig, value) + return m.ConvertValueToEntitlements(value) } -func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { +func (m EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { return statictypes.CanSkipStaticTypeMigration(valueType) } diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 85bc4c7c2f..0a707d6d4e 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -647,7 +647,7 @@ func TestConvertToEntitledType(t *testing.T) { t.Run(test.Name, func(t *testing.T) { inputStaticType := interpreter.ConvertSemaToStaticType(nil, test.Input) - convertedType, _ := ConvertToEntitledType(migration, inputStaticType) + convertedType, _ := migration.ConvertToEntitledType(inputStaticType) expectedType := interpreter.ConvertSemaToStaticType(nil, test.Output) @@ -694,8 +694,8 @@ func (m testEntitlementsMigration) Migrate( interpreter.Value, error, ) { - mig := NewEntitlementsMigration(m.inter) - return ConvertValueToEntitlements(mig, value) + migration := NewEntitlementsMigration(m.inter) + return migration.ConvertValueToEntitlements(value) } func (m testEntitlementsMigration) CanSkip(_ interpreter.StaticType) bool { @@ -1577,7 +1577,10 @@ func TestMigrateSimpleContract(t *testing.T) { func TestNilTypeValue(t *testing.T) { t.Parallel() - result, err := ConvertValueToEntitlements(NewEntitlementsMigration(nil), interpreter.NewTypeValue(nil, nil)) + migration := NewEntitlementsMigration(nil) + result, err := migration.ConvertValueToEntitlements( + interpreter.NewTypeValue(nil, nil), + ) require.NoError(t, err) require.Nil(t, result) } @@ -1585,8 +1588,8 @@ func TestNilTypeValue(t *testing.T) { func TestNilPathCapabilityValue(t *testing.T) { t.Parallel() - result, err := ConvertValueToEntitlements( - NewEntitlementsMigration(NewTestInterpreter(t)), + migration := NewEntitlementsMigration(NewTestInterpreter(t)) + result, err := migration.ConvertValueToEntitlements( &interpreter.PathCapabilityValue{ //nolint:staticcheck Address: interpreter.NewAddressValue(nil, common.MustBytesToAddress([]byte{0x1})), Path: interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "test"), @@ -2816,7 +2819,7 @@ func TestConvertDeprecatedStaticTypes(t *testing.T) { ), ) - result, err := ConvertValueToEntitlements(migration, value) + result, err := migration.ConvertValueToEntitlements(value) require.Error(t, err) assert.ErrorContains(t, err, "cannot migrate deprecated type") require.Nil(t, result) @@ -2863,7 +2866,7 @@ func TestConvertMigratedAccountTypes(t *testing.T) { require.NoError(t, err) require.NotNil(t, newValue) - result, err := ConvertValueToEntitlements(migration, newValue) + result, err := migration.ConvertValueToEntitlements(newValue) require.NoError(t, err) require.Nilf(t, result, "expected no migration, but got %s", result) })