From 5e50c45112dea0c3e415cef9554d8d871f592e54 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Wed, 28 Feb 2024 17:26:11 +0530 Subject: [PATCH] Add test for nested dictionary migration --- migrations/migration.go | 38 ++-- .../statictypes/statictype_migration.go | 1 + .../statictypes/statictype_migration_test.go | 180 ++++++++++++++++++ 3 files changed, 205 insertions(+), 14 deletions(-) diff --git a/migrations/migration.go b/migrations/migration.go index b06b559324..90fa0d829e 100644 --- a/migrations/migration.go +++ b/migrations/migration.go @@ -138,7 +138,7 @@ func (m *StorageMigration) MigrateNestedValue( value interpreter.Value, valueMigrations []ValueMigration, reporter Reporter, -) (newValue interpreter.Value) { +) (migratedValue interpreter.Value) { defer func() { // Here it catches the panics that may occur at the framework level, @@ -164,9 +164,11 @@ func (m *StorageMigration) MigrateNestedValue( } }() - switch value := value.(type) { + // Visit the children first, and migrate them. + // i.e: depth-first traversal + switch typedValue := value.(type) { case *interpreter.SomeValue: - innerValue := value.InnerValue(m.interpreter, emptyLocationRange) + innerValue := typedValue.InnerValue(m.interpreter, emptyLocationRange) newInnerValue := m.MigrateNestedValue( storageKey, storageMapKey, @@ -175,13 +177,14 @@ func (m *StorageMigration) MigrateNestedValue( reporter, ) if newInnerValue != nil { - return interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue) - } + migratedValue = interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue) - return + // chain the migrations + value = migratedValue + } case *interpreter.ArrayValue: - array := value + array := typedValue // Migrate array elements count := array.Count() @@ -205,7 +208,7 @@ func (m *StorageMigration) MigrateNestedValue( } case *interpreter.CompositeValue: - composite := value + composite := typedValue // Read the field names first, so the iteration wouldn't be affected // by the modification of the nested values. @@ -243,7 +246,7 @@ func (m *StorageMigration) MigrateNestedValue( } case *interpreter.DictionaryValue: - dictionary := value + dictionary := typedValue type keyValuePair struct { key, value interpreter.Value @@ -336,24 +339,31 @@ func (m *StorageMigration) MigrateNestedValue( } case *interpreter.PublishedValue: - innerValue := value.Value + publishedValue := typedValue newInnerValue := m.MigrateNestedValue( storageKey, storageMapKey, - innerValue, + publishedValue.Value, valueMigrations, reporter, ) if newInnerValue != nil { newInnerCapability := newInnerValue.(*interpreter.IDCapabilityValue) - return interpreter.NewPublishedValue( + migratedValue = interpreter.NewPublishedValue( m.interpreter, - value.Recipient, + publishedValue.Recipient, newInnerCapability, ) + + // chain the migrations + value = migratedValue } } + // Once the children are migrated, then migrate the current/wrapper value. + // Result of each migration is passed as the input to the next migration. + // i.e: A single value is migrated by all the migrations, before moving onto the next value. + for _, migration := range valueMigrations { convertedValue, err := m.migrate( migration, @@ -395,7 +405,7 @@ func (m *StorageMigration) MigrateNestedValue( // Chain the migrations. value = convertedValue - newValue = convertedValue + migratedValue = convertedValue if reporter != nil { reporter.Migrated( diff --git a/migrations/statictypes/statictype_migration.go b/migrations/statictypes/statictype_migration.go index a0b7c7132d..35f54e1cfd 100644 --- a/migrations/statictypes/statictype_migration.go +++ b/migrations/statictypes/statictype_migration.go @@ -157,6 +157,7 @@ func (m *StaticTypeMigration) Migrate( var keysAndValues []interpreter.Value iterator := value.Iterator() + for { keyValue, value := iterator.Next(inter) if keyValue == nil { diff --git a/migrations/statictypes/statictype_migration_test.go b/migrations/statictypes/statictype_migration_test.go index 189345cff6..a16194e8a5 100644 --- a/migrations/statictypes/statictype_migration_test.go +++ b/migrations/statictypes/statictype_migration_test.go @@ -24,6 +24,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/atree" + "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/common" @@ -580,3 +582,181 @@ func TestStaticTypeMigration(t *testing.T) { }) } + +func TestMigratingNestedContainers(t *testing.T) { + t.Parallel() + + migrate := func( + t *testing.T, + staticTypeMigration *StaticTypeMigration, + storage *runtime.Storage, + inter *interpreter.Interpreter, + value interpreter.Value, + ) interpreter.Value { + + // Store values + + storageMapKey := interpreter.StringStorageMapKey("test_type_value") + storageDomain := common.PathDomainStorage.Identifier() + + value = value.Transfer( + inter, + interpreter.EmptyLocationRange, + atree.Address(testAddress), + false, + nil, + nil, + ) + + inter.WriteStored( + testAddress, + storageDomain, + storageMapKey, + value, + ) + + err := storage.Commit(inter, true) + require.NoError(t, err) + + // Migrate + + migration := migrations.NewStorageMigration(inter, storage) + + reporter := newTestReporter() + + migration.Migrate( + &migrations.AddressSliceIterator{ + Addresses: []common.Address{ + testAddress, + }, + }, + migration.NewValueMigrationsPathMigrator( + reporter, + staticTypeMigration, + ), + ) + + err = migration.Commit() + require.NoError(t, err) + + require.Empty(t, reporter.errors) + + storageMap := storage.GetStorageMap( + testAddress, + storageDomain, + false, + ) + require.NotNil(t, storageMap) + require.Equal(t, uint64(1), storageMap.Count()) + + result := storageMap.ReadValue(nil, storageMapKey) + require.NotNil(t, value) + + return result + } + + t.Run("nested dictionary", func(t *testing.T) { + t.Parallel() + + staticTypeMigration := NewStaticTypeMigration() + + locationRange := interpreter.EmptyLocationRange + + ledger := NewTestLedger(nil, nil) + storage := runtime.NewStorage(ledger, nil) + + inter, err := interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + AtreeValueValidationEnabled: false, + AtreeStorageValidationEnabled: false, + }, + ) + require.NoError(t, err) + + storedValue := interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewCapabilityStaticType( + nil, + interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck + ), + ), + ), + interpreter.NewUnmeteredStringValue("key"), + interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewCapabilityStaticType( + nil, + interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck + ), + ), + interpreter.NewUnmeteredStringValue("key"), + interpreter.NewCapabilityValue( + nil, + interpreter.NewUnmeteredUInt64Value(1234), + interpreter.NewAddressValue(nil, common.ZeroAddress), + interpreter.PrimitiveStaticTypePublicAccount, //nolint:staticcheck + ), + ), + ) + + actual := migrate(t, + staticTypeMigration, + storage, + inter, + storedValue, + ) + + expected := interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewCapabilityStaticType( + nil, + unauthorizedAccountReferenceType, + ), + ), + ), + interpreter.NewUnmeteredStringValue("key"), + interpreter.NewDictionaryValue( + inter, + locationRange, + interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + interpreter.NewCapabilityStaticType( + nil, + unauthorizedAccountReferenceType, + ), + ), + interpreter.NewUnmeteredStringValue("key"), + interpreter.NewCapabilityValue( + nil, + interpreter.NewUnmeteredUInt64Value(1234), + interpreter.NewAddressValue(nil, common.Address{}), + unauthorizedAccountReferenceType, + ), + ), + ) + + utils.AssertValuesEqual(t, inter, expected, actual) + }) +}