diff --git a/runtime/ast/memberindices.go b/runtime/ast/memberindices.go index 93007a67c2..ca21bdba01 100644 --- a/runtime/ast/memberindices.go +++ b/runtime/ast/memberindices.go @@ -63,6 +63,8 @@ type memberIndices struct { _attachments []*AttachmentDeclaration // Use `EnumCases()` instead _enumCases []*EnumCaseDeclaration + // Use `Pragmas()` instead + _pragmas []*PragmaDeclaration } func (i *memberIndices) FieldsByIdentifier(declarations []Declaration) map[string]*FieldDeclaration { @@ -150,6 +152,11 @@ func (i *memberIndices) EnumCases(declarations []Declaration) []*EnumCaseDeclara return i._enumCases } +func (i *memberIndices) Pragmas(declarations []Declaration) []*PragmaDeclaration { + i.once.Do(i.initializer(declarations)) + return i._pragmas +} + func (i *memberIndices) initializer(declarations []Declaration) func() { return func() { i.init(declarations) @@ -184,6 +191,7 @@ func (i *memberIndices) init(declarations []Declaration) { i._entitlementMappingsByIdentifier = make(map[string]*EntitlementMappingDeclaration) i._enumCases = make([]*EnumCaseDeclaration, 0) + i._pragmas = make([]*PragmaDeclaration, 0) for _, declaration := range declarations { switch declaration := declaration.(type) { @@ -225,6 +233,9 @@ func (i *memberIndices) init(declarations []Declaration) { case *EnumCaseDeclaration: i._enumCases = append(i._enumCases, declaration) + + case *PragmaDeclaration: + i._pragmas = append(i._pragmas, declaration) } } } diff --git a/runtime/ast/members.go b/runtime/ast/members.go index a438d67566..30ebfd4773 100644 --- a/runtime/ast/members.go +++ b/runtime/ast/members.go @@ -84,6 +84,10 @@ func (m *Members) EnumCases() []*EnumCaseDeclaration { return m.indices.EnumCases(m.declarations) } +func (m *Members) Pragmas() []*PragmaDeclaration { + return m.indices.Pragmas(m.declarations) +} + func (m *Members) FieldsByIdentifier() map[string]*FieldDeclaration { return m.indices.FieldsByIdentifier(m.declarations) } diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index b8274f6523..e747a87682 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -3060,3 +3060,323 @@ func TestRuntimeContractUpdateProgramCaching(t *testing.T) { ) }) } + +func TestPragmaUpdates(t *testing.T) { + t.Parallel() + + testWithValidators(t, "Remove pragma", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #foo(bar) + #baz + } + ` + + const newCode = ` + access(all) contract Test { + #baz + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "Remove removedType pragma", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #removedType(bar) + #baz + } + ` + + const newCode = ` + access(all) contract Test { + #baz + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.TypeRemovalPragmaRemovalError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "removedType pragma moved into subdeclaration", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #removedType(bar) + access(all) struct S { + + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) struct S { + #removedType(bar) + } + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.TypeRemovalPragmaRemovalError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "reorder removedType pragmas", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #removedType(bar) + #removedType(foo) + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(foo) + #removedType(bar) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "malformed removedType pragma integer", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #baz + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(3) + #baz + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.InvalidTypeRemovalPragmaError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "malformed removedType qualified name", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #baz + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(X.Y) + #baz + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.InvalidTypeRemovalPragmaError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "removedType with zero args", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + } + ` + + const newCode = ` + access(all) contract Test { + #removedType() + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.InvalidTypeRemovalPragmaError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "removedType with two args", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(x, y) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.InvalidTypeRemovalPragmaError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "#removedType allows type removal", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource R {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "#removedType allows two type removals", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource R {} + access(all) struct interface I {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + #removedType(I) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "#removedType can be added", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + #removedType(I) + access(all) resource R {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + #removedType(I) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "#removedType can be added without removing a type", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(X) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + require.NoError(t, err) + }) + + testWithValidators(t, "declarations cannot co-exist with removed type of the same name, composite", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource R {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + access(all) resource R {} + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.UseOfRemovedTypeError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "declarations cannot co-exist with removed type of the same name, interface", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource interface R {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + access(all) resource interface R {} + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.UseOfRemovedTypeError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "declarations cannot co-exist with removed type of the same name, attachment", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) attachment R for AnyResource {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + access(all) attachment R for AnyResource {} + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.UseOfRemovedTypeError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "#removedType is only scoped to the current declaration, inner", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource R {} + access(all) struct S {} + } + ` + + const newCode = ` + access(all) contract Test { + access(all) struct S { + #removedType(R) + } + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.MissingDeclarationError + require.ErrorAs(t, err, &expectedErr) + }) +} diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index b05311f968..867026a8d6 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -23,6 +23,7 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/common/orderedmap" "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" @@ -741,6 +742,7 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) checkNestedDeclarationR nestedDeclaration ast.Declaration, oldContainingDeclaration ast.Declaration, newContainingDeclaration ast.Declaration, + removedTypes *orderedmap.OrderedMap[string, struct{}], ) { // enums can be removed from contract interfaces, as they have no interface equivalent and are not @@ -755,6 +757,7 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) checkNestedDeclarationR nestedDeclaration, oldContainingDeclaration, newContainingDeclaration, + removedTypes, ) } diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index 9831e7a6f2..d43840d9dd 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -24,6 +24,7 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/common/orderedmap" "github.com/onflow/cadence/runtime/errors" ) @@ -41,6 +42,7 @@ type UpdateValidator interface { nestedDeclaration ast.Declaration, oldContainingDeclaration ast.Declaration, newContainingDeclaration ast.Declaration, + removedTypes *orderedmap.OrderedMap[string, struct{}], ) getAccountContractNames(address common.Address) ([]string, error) @@ -215,6 +217,39 @@ func (validator *ContractUpdateValidator) hasErrors() bool { return len(validator.errors) > 0 } +func collectRemovedTypePragmas(validator UpdateValidator, pragmas []*ast.PragmaDeclaration) *orderedmap.OrderedMap[string, struct{}] { + removedTypes := orderedmap.New[orderedmap.OrderedMap[string, struct{}]](len(pragmas)) + + for _, pragma := range pragmas { + invocationExpression, isInvocation := pragma.Expression.(*ast.InvocationExpression) + if !isInvocation { + continue + } + invokedIdentifier, isIdentifier := invocationExpression.InvokedExpression.(*ast.IdentifierExpression) + if !isIdentifier || invokedIdentifier.Identifier.Identifier != "removedType" { + continue + } + if len(invocationExpression.Arguments) != 1 { + validator.report(&InvalidTypeRemovalPragmaError{ + Expression: pragma.Expression, + Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), + }) + continue + } + removedTypeName, isIdentifer := invocationExpression.Arguments[0].Expression.(*ast.IdentifierExpression) + if !isIdentifer { + validator.report(&InvalidTypeRemovalPragmaError{ + Expression: pragma.Expression, + Range: ast.NewUnmeteredRangeFromPositioned(pragma.Expression), + }) + continue + } + removedTypes.Set(removedTypeName.Identifier.Identifier, struct{}{}) + } + + return removedTypes +} + func checkDeclarationUpdatability( validator UpdateValidator, oldDeclaration ast.Declaration, @@ -310,12 +345,18 @@ func (validator *ContractUpdateValidator) checkNestedDeclarationRemoval( nestedDeclaration ast.Declaration, _ ast.Declaration, newContainingDeclaration ast.Declaration, + removedTypes *orderedmap.OrderedMap[string, struct{}], ) { // OK to remove events - they are not stored if nestedDeclaration.DeclarationKind() == common.DeclarationKindEvent { return } + // OK to remove a type if it is included in a #removedType pragma + if removedTypes.Contains(nestedDeclaration.DeclarationIdentifier().Identifier) { + return + } + validator.report(&MissingDeclarationError{ Name: nestedDeclaration.DeclarationIdentifier().Identifier, Kind: nestedDeclaration.DeclarationKind(), @@ -334,6 +375,19 @@ func (validator *ContractUpdateValidator) oldTypeID(oldType *ast.NominalType) co return oldImportLocation.TypeID(nil, qualifiedIdentifier) } +func checkTypeNotRemoved( + validator UpdateValidator, + newDeclaration ast.Declaration, + removedTypes *orderedmap.OrderedMap[string, struct{}], +) { + if removedTypes.Contains(newDeclaration.DeclarationIdentifier().Identifier) { + validator.report(&UseOfRemovedTypeError{ + Declaration: newDeclaration, + Range: ast.NewUnmeteredRangeFromPositioned(newDeclaration), + }) + } +} + func checkNestedDeclarations( validator UpdateValidator, oldDeclaration ast.Declaration, @@ -341,11 +395,26 @@ func checkNestedDeclarations( checkConformance checkConformanceFunc, ) { + // process pragmas first, as they determine whether types can later be removed + oldRemovedTypes := collectRemovedTypePragmas(validator, oldDeclaration.DeclarationMembers().Pragmas()) + removedTypes := collectRemovedTypePragmas(validator, newDeclaration.DeclarationMembers().Pragmas()) + + // #typeRemoval pragmas cannot be removed, so any that appear in the old program must appear in the new program + // they can however, be added, so use the new program's type removals for the purposes of checking the upgrade + oldRemovedTypes.Foreach(func(oldRemovedType string, _ struct{}) { + if !removedTypes.Contains(oldRemovedType) { + validator.report(&TypeRemovalPragmaRemovalError{ + RemovedType: oldRemovedType, + }) + } + }) + oldNominalTypeDecls := getNestedNominalTypeDecls(oldDeclaration) // Check nested structs, enums, etc. newNestedCompositeDecls := newDeclaration.DeclarationMembers().Composites() for _, newNestedDecl := range newNestedCompositeDecls { + checkTypeNotRemoved(validator, newNestedDecl, removedTypes) oldNestedDecl, found := oldNominalTypeDecls[newNestedDecl.Identifier.Identifier] if !found { // Then it's a new declaration @@ -361,6 +430,7 @@ func checkNestedDeclarations( // Check nested attachments, etc. newNestedAttachmentDecls := newDeclaration.DeclarationMembers().Attachments() for _, newNestedDecl := range newNestedAttachmentDecls { + checkTypeNotRemoved(validator, newNestedDecl, removedTypes) oldNestedDecl, found := oldNominalTypeDecls[newNestedDecl.Identifier.Identifier] if !found { // Then it's a new declaration @@ -376,6 +446,7 @@ func checkNestedDeclarations( // Check nested interfaces. newNestedInterfaces := newDeclaration.DeclarationMembers().Interfaces() for _, newNestedDecl := range newNestedInterfaces { + checkTypeNotRemoved(validator, newNestedDecl, removedTypes) oldNestedDecl, found := oldNominalTypeDecls[newNestedDecl.Identifier.Identifier] if !found { // Then this is a new declaration. @@ -404,7 +475,7 @@ func checkNestedDeclarations( }) for _, declaration := range missingDeclarations { - validator.checkNestedDeclarationRemoval(declaration, oldDeclaration, newDeclaration) + validator.checkNestedDeclarationRemoval(declaration, oldDeclaration, newDeclaration, removedTypes) } // Check enum-cases, if there are any. @@ -774,3 +845,56 @@ func (e *MissingDeclarationError) Error() string { e.Name, ) } + +// InvalidTypeRemovalPragmaError is reported during a contract update +// if a malformed #removedType pragma is encountered +type InvalidTypeRemovalPragmaError struct { + Expression ast.Expression + ast.Range +} + +var _ errors.UserError = &InvalidTypeRemovalPragmaError{} + +func (*InvalidTypeRemovalPragmaError) IsUserError() {} + +func (e *InvalidTypeRemovalPragmaError) Error() string { + return fmt.Sprintf( + "invalid #removedType pragma: %s", + e.Expression.String(), + ) +} + +// UseOfRemovedTypeError is reported during a contract update +// if a type is encountered that is also in a #removedType pragma +type UseOfRemovedTypeError struct { + Declaration ast.Declaration + ast.Range +} + +var _ errors.UserError = &UseOfRemovedTypeError{} + +func (*UseOfRemovedTypeError) IsUserError() {} + +func (e *UseOfRemovedTypeError) Error() string { + return fmt.Sprintf( + "cannot declare %s, type has been removed with a #removedType pragma", + e.Declaration.DeclarationIdentifier(), + ) +} + +// TypeRemovalPragmaRemovalError is reported during a contract update +// if a #removedType pragma is removed +type TypeRemovalPragmaRemovalError struct { + RemovedType string +} + +var _ errors.UserError = &TypeRemovalPragmaRemovalError{} + +func (*TypeRemovalPragmaRemovalError) IsUserError() {} + +func (e *TypeRemovalPragmaRemovalError) Error() string { + return fmt.Sprintf( + "missing #removedType pragma for %s", + e.RemovedType, + ) +}