Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for nested container value migration #3142

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions migrations/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (m *StorageMigration) MigrateNestedValue(
value interpreter.Value,
valueMigrations []ValueMigration,
reporter Reporter,
) (newValue interpreter.Value) {
) (migratedValue interpreter.Value) {

defer func() {
// Here it catches the panics that may occur at the framework level,
Expand All @@ -164,9 +164,11 @@ func (m *StorageMigration) MigrateNestedValue(
}
}()

switch value := value.(type) {
// Visit the children first, and migrate them.
// i.e: depth-first traversal
switch typedValue := value.(type) {
case *interpreter.SomeValue:
innerValue := value.InnerValue(m.interpreter, emptyLocationRange)
innerValue := typedValue.InnerValue(m.interpreter, emptyLocationRange)
newInnerValue := m.MigrateNestedValue(
storageKey,
storageMapKey,
Expand All @@ -175,13 +177,14 @@ func (m *StorageMigration) MigrateNestedValue(
reporter,
)
if newInnerValue != nil {
return interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue)
}
migratedValue = interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue)

return
// chain the migrations
value = migratedValue
}

case *interpreter.ArrayValue:
array := value
array := typedValue

// Migrate array elements
count := array.Count()
Expand All @@ -205,7 +208,7 @@ func (m *StorageMigration) MigrateNestedValue(
}

case *interpreter.CompositeValue:
composite := value
composite := typedValue

// Read the field names first, so the iteration wouldn't be affected
// by the modification of the nested values.
Expand Down Expand Up @@ -243,7 +246,7 @@ func (m *StorageMigration) MigrateNestedValue(
}

case *interpreter.DictionaryValue:
dictionary := value
dictionary := typedValue

type keyValuePair struct {
key, value interpreter.Value
Expand Down Expand Up @@ -336,24 +339,31 @@ func (m *StorageMigration) MigrateNestedValue(
}

case *interpreter.PublishedValue:
innerValue := value.Value
publishedValue := typedValue
newInnerValue := m.MigrateNestedValue(
storageKey,
storageMapKey,
innerValue,
publishedValue.Value,
valueMigrations,
reporter,
)
if newInnerValue != nil {
newInnerCapability := newInnerValue.(*interpreter.IDCapabilityValue)
return interpreter.NewPublishedValue(
migratedValue = interpreter.NewPublishedValue(
m.interpreter,
value.Recipient,
publishedValue.Recipient,
newInnerCapability,
)

// chain the migrations
value = migratedValue
}
}

// Once the children are migrated, then migrate the current/wrapper value.
// Result of each migration is passed as the input to the next migration.
// i.e: A single value is migrated by all the migrations, before moving onto the next value.

for _, migration := range valueMigrations {
convertedValue, err := m.migrate(
migration,
Expand Down Expand Up @@ -395,7 +405,7 @@ func (m *StorageMigration) MigrateNestedValue(
// Chain the migrations.
value = convertedValue

newValue = convertedValue
migratedValue = convertedValue

if reporter != nil {
reporter.Migrated(
Expand Down
1 change: 1 addition & 0 deletions migrations/statictypes/statictype_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ func (m *StaticTypeMigration) Migrate(
var keysAndValues []interpreter.Value

iterator := value.Iterator()

for {
keyValue, value := iterator.Next(inter)
if keyValue == nil {
Expand Down
180 changes: 180 additions & 0 deletions migrations/statictypes/statictype_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/onflow/atree"

"github.com/onflow/cadence/migrations"
"github.com/onflow/cadence/runtime"
"github.com/onflow/cadence/runtime/common"
Expand Down Expand Up @@ -580,3 +582,181 @@ func TestStaticTypeMigration(t *testing.T) {

})
}

func TestMigratingNestedContainers(t *testing.T) {
t.Parallel()

migrate := func(
t *testing.T,
staticTypeMigration *StaticTypeMigration,
storage *runtime.Storage,
inter *interpreter.Interpreter,
value interpreter.Value,
) interpreter.Value {

// Store values

storageMapKey := interpreter.StringStorageMapKey("test_type_value")
storageDomain := common.PathDomainStorage.Identifier()

value = value.Transfer(
inter,
interpreter.EmptyLocationRange,
atree.Address(testAddress),
false,
nil,
nil,
)

inter.WriteStored(
testAddress,
storageDomain,
storageMapKey,
value,
)

err := storage.Commit(inter, true)
require.NoError(t, err)

// Migrate

migration := migrations.NewStorageMigration(inter, storage)

reporter := newTestReporter()

migration.Migrate(
&migrations.AddressSliceIterator{
Addresses: []common.Address{
testAddress,
},
},
migration.NewValueMigrationsPathMigrator(
reporter,
staticTypeMigration,
),
)

err = migration.Commit()
require.NoError(t, err)

require.Empty(t, reporter.errors)

storageMap := storage.GetStorageMap(
testAddress,
storageDomain,
false,
)
require.NotNil(t, storageMap)
require.Equal(t, uint64(1), storageMap.Count())

result := storageMap.ReadValue(nil, storageMapKey)
require.NotNil(t, value)

return result
}

t.Run("nested dictionary", func(t *testing.T) {
t.Parallel()

staticTypeMigration := NewStaticTypeMigration()

locationRange := interpreter.EmptyLocationRange

ledger := NewTestLedger(nil, nil)
storage := runtime.NewStorage(ledger, nil)

inter, err := interpreter.NewInterpreter(
nil,
utils.TestLocation,
&interpreter.Config{
Storage: storage,
AtreeValueValidationEnabled: false,
AtreeStorageValidationEnabled: false,
},
)
require.NoError(t, err)

storedValue := interpreter.NewDictionaryValue(
inter,
locationRange,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewCapabilityStaticType(
nil,
interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck
),
),
),
interpreter.NewUnmeteredStringValue("key"),
interpreter.NewDictionaryValue(
inter,
locationRange,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewCapabilityStaticType(
nil,
interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck
),
),
interpreter.NewUnmeteredStringValue("key"),
interpreter.NewCapabilityValue(
nil,
interpreter.NewUnmeteredUInt64Value(1234),
interpreter.NewAddressValue(nil, common.ZeroAddress),
interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck
),
),
)

actual := migrate(t,
staticTypeMigration,
storage,
inter,
storedValue,
)

expected := interpreter.NewDictionaryValue(
inter,
locationRange,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewCapabilityStaticType(
nil,
unauthorizedAccountReferenceType,
),
),
),
interpreter.NewUnmeteredStringValue("key"),
interpreter.NewDictionaryValue(
inter,
locationRange,
interpreter.NewDictionaryStaticType(
nil,
interpreter.PrimitiveStaticTypeString,
interpreter.NewCapabilityStaticType(
nil,
unauthorizedAccountReferenceType,
),
),
interpreter.NewUnmeteredStringValue("key"),
interpreter.NewCapabilityValue(
nil,
interpreter.NewUnmeteredUInt64Value(1234),
interpreter.NewAddressValue(nil, common.Address{}),
unauthorizedAccountReferenceType,
),
),
)

utils.AssertValuesEqual(t, inter, expected, actual)
})
}
Loading