Skip to content

Commit

Permalink
Merge pull request #3232 from onflow/sainati/entitlement-type-migrati…
Browse files Browse the repository at this point in the history
…on-cache

Cache results of entitlement type conversion
  • Loading branch information
dsainati1 authored Apr 24, 2024
2 parents ba62ba7 + 6f4698e commit 1ea9852
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 55 deletions.
41 changes: 41 additions & 0 deletions migrations/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Cadence - The resource-oriented smart contract programming language
*
* Copyright Dapper Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package migrations

import (
"sync"

"github.com/onflow/cadence/runtime/interpreter"
)

type StaticTypeCache struct {
entries sync.Map
}

func (c *StaticTypeCache) Get(typeID interpreter.TypeID) (interpreter.StaticType, bool) {
v, ok := c.entries.Load(typeID)
if !ok {
return nil, false
}
return v.(interpreter.StaticType), true
}

func (c *StaticTypeCache) Set(typeID interpreter.TypeID, ty interpreter.StaticType) {
c.entries.Store(typeID, ty)
}
129 changes: 82 additions & 47 deletions migrations/entitlements/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,24 @@ import (
)

type EntitlementsMigration struct {
Interpreter *interpreter.Interpreter
Interpreter *interpreter.Interpreter
migratedTypeCache *migrations.StaticTypeCache
}

var _ migrations.ValueMigration = EntitlementsMigration{}

func NewEntitlementsMigration(inter *interpreter.Interpreter) EntitlementsMigration {
return EntitlementsMigration{Interpreter: inter}
return NewEntitlementsMigrationWithCache(inter, &migrations.StaticTypeCache{})
}

func NewEntitlementsMigrationWithCache(
inter *interpreter.Interpreter,
migratedTypeCache *migrations.StaticTypeCache,
) EntitlementsMigration {
return EntitlementsMigration{
Interpreter: inter,
migratedTypeCache: migratedTypeCache,
}
}

func (EntitlementsMigration) Name() string {
Expand All @@ -57,12 +68,11 @@ func (EntitlementsMigration) Domains() map[string]struct{} {
// where `Entitlements(I)` is defined as the result of `T.SupportedEntitlements()`
//
// TODO: functions?
func ConvertToEntitledType(
inter *interpreter.Interpreter,
func (m EntitlementsMigration) ConvertToEntitledType(
staticType interpreter.StaticType,
) (
interpreter.StaticType,
error,
resultType interpreter.StaticType,
conversionErr error,
) {
if staticType == nil {
return nil, nil
Expand All @@ -72,14 +82,30 @@ func ConvertToEntitledType(
return nil, fmt.Errorf("cannot migrate deprecated type: %s", staticType)
}

inter := m.Interpreter
migratedTypeCache := m.migratedTypeCache

staticTypeID := staticType.ID()

if migratedType, exists := migratedTypeCache.Get(staticTypeID); exists {
return migratedType, nil
}

defer func() {
if resultType != nil && conversionErr == nil {
migratedTypeCache.Set(staticTypeID, resultType)
}
}()

switch t := staticType.(type) {
case *interpreter.ReferenceStaticType:

referencedType := t.ReferencedType

convertedReferencedType, err := ConvertToEntitledType(inter, referencedType)
convertedReferencedType, err := m.ConvertToEntitledType(referencedType)
if err != nil {
return nil, err
conversionErr = err
return
}

var returnNew bool
Expand Down Expand Up @@ -137,100 +163,109 @@ func ConvertToEntitledType(
}

if returnNew {
return interpreter.NewReferenceStaticType(nil, auth, referencedType), nil
resultType = interpreter.NewReferenceStaticType(nil, auth, referencedType)
return
}

case *interpreter.CapabilityStaticType:
convertedBorrowType, err := ConvertToEntitledType(inter, t.BorrowType)
convertedBorrowType, err := m.ConvertToEntitledType(t.BorrowType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedBorrowType != nil {
return interpreter.NewCapabilityStaticType(nil, convertedBorrowType), nil
resultType = interpreter.NewCapabilityStaticType(nil, convertedBorrowType)
return
}

case *interpreter.VariableSizedStaticType:
elementType := t.Type

convertedElementType, err := ConvertToEntitledType(inter, elementType)
convertedElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedElementType != nil {
return interpreter.NewVariableSizedStaticType(nil, convertedElementType), nil
resultType = interpreter.NewVariableSizedStaticType(nil, convertedElementType)
return
}

case *interpreter.ConstantSizedStaticType:
elementType := t.Type

convertedElementType, err := ConvertToEntitledType(inter, elementType)
convertedElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedElementType != nil {
return interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size), nil
resultType = interpreter.NewConstantSizedStaticType(nil, convertedElementType, t.Size)
return
}

case *interpreter.DictionaryStaticType:
keyType := t.KeyType

convertedKeyType, err := ConvertToEntitledType(inter, keyType)
convertedKeyType, err := m.ConvertToEntitledType(keyType)
if err != nil {
return nil, err
conversionErr = err
return
}

valueType := t.ValueType

convertedValueType, err := ConvertToEntitledType(inter, valueType)
convertedValueType, err := m.ConvertToEntitledType(valueType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedKeyType != nil {
if convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, convertedValueType)
return
} else {
return interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, convertedKeyType, valueType)
return
}
} else if convertedValueType != nil {
return interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType), nil
resultType = interpreter.NewDictionaryStaticType(nil, keyType, convertedValueType)
return
}

case *interpreter.OptionalStaticType:
innerType := t.Type

convertedInnerType, err := ConvertToEntitledType(inter, innerType)
convertedInnerType, err := m.ConvertToEntitledType(innerType)
if err != nil {
return nil, err
conversionErr = err
return
}

if convertedInnerType != nil {
return interpreter.NewOptionalStaticType(nil, convertedInnerType), nil
resultType = interpreter.NewOptionalStaticType(nil, convertedInnerType)
return
}
}

return nil, nil
return
}

// ConvertValueToEntitlements converts the input value into a version compatible with the new entitlements feature,
// with the same members/operations accessible on any references as would have been accessible in the past.
func ConvertValueToEntitlements(
inter *interpreter.Interpreter,
v interpreter.Value,
) (
interpreter.Value,
error,
) {
func (m EntitlementsMigration) ConvertValueToEntitlements(v interpreter.Value) (interpreter.Value, error) {
inter := m.Interpreter

switch v := v.(type) {

case *interpreter.ArrayValue:
elementType := v.Type

entitledElementType, err := ConvertToEntitledType(inter, elementType)
entitledElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
}
Expand All @@ -246,7 +281,7 @@ func ConvertValueToEntitlements(
case *interpreter.DictionaryValue:
elementType := v.Type

entitledElementType, err := ConvertToEntitledType(inter, elementType)
entitledElementType, err := m.ConvertToEntitledType(elementType)
if err != nil {
return nil, err
}
Expand All @@ -262,7 +297,7 @@ func ConvertValueToEntitlements(
case *interpreter.IDCapabilityValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -279,7 +314,7 @@ func ConvertValueToEntitlements(
case *interpreter.PathCapabilityValue: //nolint:staticcheck
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -295,7 +330,7 @@ func ConvertValueToEntitlements(
case interpreter.TypeValue:
ty := v.Type

entitledType, err := ConvertToEntitledType(inter, ty)
entitledType, err := m.ConvertToEntitledType(ty)
if err != nil {
return nil, err
}
Expand All @@ -307,7 +342,7 @@ func ConvertValueToEntitlements(
case *interpreter.AccountCapabilityControllerValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -323,7 +358,7 @@ func ConvertValueToEntitlements(
case *interpreter.StorageCapabilityControllerValue:
borrowType := v.BorrowType

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -340,7 +375,7 @@ func ConvertValueToEntitlements(
case interpreter.PathLinkValue: //nolint:staticcheck
borrowType := v.Type

entitledBorrowType, err := ConvertToEntitledType(inter, borrowType)
entitledBorrowType, err := m.ConvertToEntitledType(borrowType)
if err != nil {
return nil, err
}
Expand All @@ -356,7 +391,7 @@ func ConvertValueToEntitlements(
return nil, nil
}

func (mig EntitlementsMigration) Migrate(
func (m EntitlementsMigration) Migrate(
_ interpreter.StorageKey,
_ interpreter.StorageMapKey,
value interpreter.Value,
Expand All @@ -365,9 +400,9 @@ func (mig EntitlementsMigration) Migrate(
interpreter.Value,
error,
) {
return ConvertValueToEntitlements(mig.Interpreter, value)
return m.ConvertValueToEntitlements(value)
}

func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool {
func (m EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool {
return statictypes.CanSkipStaticTypeMigration(valueType)
}
Loading

0 comments on commit 1ea9852

Please sign in to comment.