Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache results of entitlement type conversion #3232

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions migrations/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Cadence - The resource-oriented smart contract programming language
*
* Copyright Dapper Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package migrations

import (
"sync"

"github.com/onflow/cadence/runtime/interpreter"
)

type StaticTypeCache struct {
entries sync.Map
}

func (c *StaticTypeCache) Get(typeID interpreter.TypeID) (interpreter.StaticType, bool) {
v, ok := c.entries.Load(typeID)
if !ok {
return nil, false
}
return v.(interpreter.StaticType), true
}

func (c *StaticTypeCache) Set(typeID interpreter.TypeID, ty interpreter.StaticType) {
c.entries.Store(typeID, ty)
}
129 changes: 82 additions & 47 deletions migrations/entitlements/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,24 @@ import (
)

type EntitlementsMigration struct {
Interpreter *interpreter.Interpreter
Interpreter *interpreter.Interpreter
migratedTypeCache *migrations.StaticTypeCache
}

var _ migrations.ValueMigration = EntitlementsMigration{}

func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration {
return EntitlementsMigration{Interpreter: inter}
return NewEntitlementsMigrationWithCache(inter, &migrations.StaticTypeCache{})
}

func NewEntitlementsMigrationWithCache(
inter *interpreter.Interpreter,
migratedTypeCache *migrations.StaticTypeCache,
) EntitlementsMigration {
return EntitlementsMigration{
Interpreter: inter,
migratedTypeCache: migratedTypeCache,
}
}

func (EntitlementsMigration) Name() string {
Expand All @@ -57,12 +68,11 @@ func (EntitlementsMigration) Domains() map[string]struct{} {
// where `Entitlements(I)` is defined as the result of `T.SupportedEntitlements()`
//
// TODO: functions?
func ConvertToEntitledType(
inter *interpreter.Interpreter,
func (m EntitlementsMigration) ConvertToEntitledType(
staticType interpreter.StaticType,
) (
interpreter.StaticType,
error,
resultType interpreter.StaticType,
conversionErr error,
) {
if staticType == nil {
return nil, nil
Expand All @@ -72,14 +82,30 @@ func ConvertToEntitledType(
return nil, fmt.Errorf("cannot migrate deprecated type: %s", staticType)
}

inter := m.Interpreter
migratedTypeCache := m.migratedTypeCache

staticTypeID := staticType.ID()

if migratedType, exists := migratedTypeCache.Get(staticTypeID); exists {
return migratedType, nil
}

defer func() {
if resultType != nil && conversionErr == nil {
migratedTypeCache.Set(staticTypeID, resultType)
}
}()

switch t := staticType.(type) {
case *interpreter.ReferenceStaticType:

referencedType := t.ReferencedType

convertedReferencedType, err := ConvertToEntitledType(inter, referencedType)
convertedReferencedType, err := m.ConvertToEntitledType(referencedType)
if err != nil {
return nil, err
conversionErr = err
return
}

var returnNew bool
Expand Down Expand Up @@ -137,100 +163,109 @@ func ConvertToEntitledType(
}

if returnNew {
return interpreter.NewReferenceStaticType(nil, auth, referencedType), nil
resultType = interpreter.NewReferenceStaticType(nil, auth, referencedType)
return
}

case *interpreter.CapabilityStaticType:
convertedBorrowType, err := ConvertToEntitledType(inter, t.BorrowType)
convertedBorrowType, err := m.ConvertToEntitledType(t.BorrowType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedBorrowType != nil {
return interpreter.NewCapabilityStaticType(nil, convertedBorrowType), nil
resultType = interpreter.NewCapabilityStaticType(nil, convertedBorrowType)
return
}

case *interpreter.VariableSizedStaticType:
elementType := t.Type

convertedElementType, err := ConvertToEntitledType(inter, elementType)
convertedElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedElementType != nil {
return interpreter.NewVariableSizedStaticType(nil, convertedElementType), nil
resultType = interpreter.NewVariableSizedStaticType(nil, convertedElementType)
return
}

case *interpreter.ConstantSizedStaticType:
elementType := t.Type

convertedElementType, err := ConvertToEntitledType(inter, elementType)
convertedElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedElementType != nil {
return interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size), nil
resultType = interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size)
return
}

case *interpreter.DictionaryStaticType:
keyType := t.KeyType

convertedKeyType, err := ConvertToEntitledType(inter, keyType)
convertedKeyType, err := m.ConvertToEntitledType(keyType)
if err != nil {
return nil, err
conversionErr = err
return
}

valueType := t.ValueType

convertedValueType, err := ConvertToEntitledType(inter, valueType)
convertedValueType, err := m.ConvertToEntitledType(valueType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedKeyType != nil {
if convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType)
return
} else {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType)
return
}
} else if convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType)
return
}

case *interpreter.OptionalStaticType:
innerType := t.Type

convertedInnerType, err := ConvertToEntitledType(inter, innerType)
convertedInnerType, err := m.ConvertToEntitledType(innerType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedInnerType != nil {
return interpreter.NewOptionalStaticType(nil, convertedInnerType), nil
resultType = interpreter.NewOptionalStaticType(nil, convertedInnerType)
return
}
}

return nil, nil
return
}

// ConvertValueToEntitlements converts the input value into a version compatible with the new entitlements feature,
// with the same members/operations accessible on any references as would have been accessible in the past.
func ConvertValueToEntitlements(
inter *interpreter.Interpreter,
v interpreter.Value,
) (
interpreter.Value,
error,
) {
func (m EntitlementsMigration) ConvertValueToEntitlements(v interpreter.Value) (interpreter.Value, error) {
inter := m.Interpreter

switch v := v.(type) {

case *interpreter.ArrayValue:
elementType := v.Type

entitledElementType, err := ConvertToEntitledType(inter, elementType)
entitledElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
}
Expand All @@ -246,7 +281,7 @@ func ConvertValueToEntitlements(
case *interpreter.DictionaryValue:
elementType := v.Type

entitledElementType, err := ConvertToEntitledType(inter, elementType)
entitledElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
}
Expand All @@ -262,7 +297,7 @@ func ConvertValueToEntitlements(
case *interpreter.IDCapabilityValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -279,7 +314,7 @@ func ConvertValueToEntitlements(
case *interpreter.PathCapabilityValue: //nolint:staticcheck
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -295,7 +330,7 @@ func ConvertValueToEntitlements(
case interpreter.TypeValue:
ty := v.Type

entitledType, err := ConvertToEntitledType(inter, ty)
entitledType, err := m.ConvertToEntitledType(ty)
if err != nil {
return nil, err
}
Expand All @@ -307,7 +342,7 @@ func ConvertValueToEntitlements(
case *interpreter.AccountCapabilityControllerValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -323,7 +358,7 @@ func ConvertValueToEntitlements(
case *interpreter.StorageCapabilityControllerValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -340,7 +375,7 @@ func ConvertValueToEntitlements(
case interpreter.PathLinkValue: //nolint:staticcheck
borrowType := v.Type

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -356,7 +391,7 @@ func ConvertValueToEntitlements(
return nil, nil
}

func (mig EntitlementsMigration) Migrate(
func (m EntitlementsMigration) Migrate(
_ interpreter.StorageKey,
_ interpreter.StorageMapKey,
value interpreter.Value,
Expand All @@ -365,9 +400,9 @@ func (mig EntitlementsMigration) Migrate(
interpreter.Value,
error,
) {
return ConvertValueToEntitlements(mig.Interpreter, value)
return m.ConvertValueToEntitlements(value)
}

func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool {
func (m EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool {
return statictypes.CanSkipStaticTypeMigration(valueType)
}
Loading
Loading