diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 1b86c3f65b..96ebed3636 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -245,10 +245,8 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a // e.g.2: Given T?, this returns (&T)? func (interpreter *Interpreter) getReferenceValue(value Value, semaType sema.Type) Value { switch value.(type) { - case NilValue: + case NilValue, ReferenceValue: // Reference to a nil, should return a nil. - return value - case ReferenceValue: // If the value is already a reference then return the same reference. return value } diff --git a/runtime/sema/account_capability_controller.cdc b/runtime/sema/account_capability_controller.cdc index e1813cfe0d..2d3e9f9b64 100644 --- a/runtime/sema/account_capability_controller.cdc +++ b/runtime/sema/account_capability_controller.cdc @@ -1,4 +1,4 @@ -access(all) struct AccountCapabilityController { +access(all) struct AccountCapabilityController: MemberAccessible { /// An arbitrary "tag" for the controller. /// For example, it could be used to describe the purpose of the capability. diff --git a/runtime/sema/account_capability_controller.gen.go b/runtime/sema/account_capability_controller.gen.go index 2d2d0841a9..a71550bd63 100644 --- a/runtime/sema/account_capability_controller.gen.go +++ b/runtime/sema/account_capability_controller.gen.go @@ -91,16 +91,17 @@ Borrowing from the controlled capability or its copies will return nil. const AccountCapabilityControllerTypeName = "AccountCapabilityController" var AccountCapabilityControllerType = &SimpleType{ - Name: AccountCapabilityControllerTypeName, - QualifiedName: AccountCapabilityControllerTypeName, - TypeID: AccountCapabilityControllerTypeName, - tag: AccountCapabilityControllerTypeTag, - IsResource: false, - Storable: false, - Equatable: false, - Comparable: false, - Exportable: false, - Importable: false, + Name: AccountCapabilityControllerTypeName, + QualifiedName: AccountCapabilityControllerTypeName, + TypeID: AccountCapabilityControllerTypeName, + tag: AccountCapabilityControllerTypeTag, + IsResource: false, + Storable: false, + Equatable: false, + Comparable: false, + Exportable: false, + Importable: false, + MemberAccessible: true, } func init() { diff --git a/runtime/sema/anyresource_type.go b/runtime/sema/anyresource_type.go index d1de774aa8..19e41d3133 100644 --- a/runtime/sema/anyresource_type.go +++ b/runtime/sema/anyresource_type.go @@ -30,6 +30,7 @@ var AnyResourceType = &SimpleType{ Equatable: false, Comparable: false, // The actual returnability of a value is checked at run-time - Exportable: true, - Importable: false, + Exportable: true, + Importable: false, + MemberAccessible: true, } diff --git a/runtime/sema/anystruct_type.go b/runtime/sema/anystruct_type.go index 14548e0eeb..339d81df20 100644 --- a/runtime/sema/anystruct_type.go +++ b/runtime/sema/anystruct_type.go @@ -31,7 +31,8 @@ var AnyStructType = &SimpleType{ Comparable: false, Exportable: true, // The actual importability is checked at runtime - Importable: true, + Importable: true, + MemberAccessible: true, } var AnyStructTypeAnnotation = NewTypeAnnotation(AnyStructType) diff --git a/runtime/sema/authaccount.cdc b/runtime/sema/authaccount.cdc index e987a4094b..391263a5f4 100644 --- a/runtime/sema/authaccount.cdc +++ b/runtime/sema/authaccount.cdc @@ -1,5 +1,5 @@ -access(all) struct AuthAccount { +access(all) struct AuthAccount: MemberAccessible { /// The address of the account. access(all) let address: Address diff --git a/runtime/sema/block.cdc b/runtime/sema/block.cdc index df8562d954..f2b25075b7 100644 --- a/runtime/sema/block.cdc +++ b/runtime/sema/block.cdc @@ -1,5 +1,5 @@ -access(all) struct Block { +access(all) struct Block: MemberAccessible { /// The height of the block. /// diff --git a/runtime/sema/block.gen.go b/runtime/sema/block.gen.go index 0e1b5d11cb..f74a58e29c 100644 --- a/runtime/sema/block.gen.go +++ b/runtime/sema/block.gen.go @@ -66,16 +66,17 @@ It is essentially the hash of the block const BlockTypeName = "Block" var BlockType = &SimpleType{ - Name: BlockTypeName, - QualifiedName: BlockTypeName, - TypeID: BlockTypeName, - tag: BlockTypeTag, - IsResource: false, - Storable: false, - Equatable: false, - Comparable: false, - Exportable: false, - Importable: false, + Name: BlockTypeName, + QualifiedName: BlockTypeName, + TypeID: BlockTypeName, + tag: BlockTypeTag, + IsResource: false, + Storable: false, + Equatable: false, + Comparable: false, + Exportable: false, + Importable: false, + MemberAccessible: true, } func init() { diff --git a/runtime/sema/check_member_expression.go b/runtime/sema/check_member_expression.go index 26d9ee0f63..da1f637776 100644 --- a/runtime/sema/check_member_expression.go +++ b/runtime/sema/check_member_expression.go @@ -112,7 +112,7 @@ func shouldReturnReference(parentType, memberType Type) bool { return false } - return isContainerType(memberType) + return memberType.IsMemberAccessible() } func isReferenceType(typ Type) bool { @@ -121,23 +121,6 @@ func isReferenceType(typ Type) bool { return isReference } -func isContainerType(typ Type) bool { - switch typ := typ.(type) { - case *CompositeType, - *DictionaryType, - ArrayType: - return true - case *OptionalType: - return isContainerType(typ.Type) - default: - switch typ { - case AnyStructType, AnyResourceType: - return true - } - return false - } -} - func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedType Type, resultingType Type, member *Member, isOptional bool) { memberInfo, ok := checker.Elaboration.MemberExpressionMemberInfo(expression) if ok { diff --git a/runtime/sema/deployedcontract.cdc b/runtime/sema/deployedcontract.cdc index 61f2195aa8..f3e824243a 100644 --- a/runtime/sema/deployedcontract.cdc +++ b/runtime/sema/deployedcontract.cdc @@ -1,5 +1,5 @@ -access(all) struct DeployedContract { +access(all) struct DeployedContract: MemberAccessible { /// The address of the account where the contract is deployed at. access(all) let address: Address diff --git a/runtime/sema/deployedcontract.gen.go b/runtime/sema/deployedcontract.gen.go index 87b3883b84..f473d72c81 100644 --- a/runtime/sema/deployedcontract.gen.go +++ b/runtime/sema/deployedcontract.gen.go @@ -74,16 +74,17 @@ then ` + "`.publicTypes()`" + ` will return an array equivalent to the expressio const DeployedContractTypeName = "DeployedContract" var DeployedContractType = &SimpleType{ - Name: DeployedContractTypeName, - QualifiedName: DeployedContractTypeName, - TypeID: DeployedContractTypeName, - tag: DeployedContractTypeTag, - IsResource: false, - Storable: false, - Equatable: false, - Comparable: false, - Exportable: false, - Importable: false, + Name: DeployedContractTypeName, + QualifiedName: DeployedContractTypeName, + TypeID: DeployedContractTypeName, + tag: DeployedContractTypeTag, + IsResource: false, + Storable: false, + Equatable: false, + Comparable: false, + Exportable: false, + Importable: false, + MemberAccessible: true, } func init() { diff --git a/runtime/sema/gen/main.go b/runtime/sema/gen/main.go index eaf9d25473..5c6d2fdf11 100644 --- a/runtime/sema/gen/main.go +++ b/runtime/sema/gen/main.go @@ -151,6 +151,7 @@ type typeDecl struct { exportable bool comparable bool importable bool + memberAccessible bool memberDeclarations []ast.Declaration nestedTypes []*typeDecl } @@ -423,6 +424,15 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ case "Importable": typeDecl.importable = true + + case "MemberAccessible": + if !canGenerateSimpleType { + panic(fmt.Errorf( + "composite types cannot be explicitly marked as member accessible: %s", + g.currentTypeID(), + )) + } + typeDecl.memberAccessible = true } } @@ -1158,6 +1168,7 @@ func simpleTypeLiteral(ty *typeDecl) dst.Expr { goKeyValue("Comparable", goBoolLit(ty.comparable)), goKeyValue("Exportable", goBoolLit(ty.exportable)), goKeyValue("Importable", goBoolLit(ty.importable)), + goKeyValue("MemberAccessible", goBoolLit(ty.memberAccessible)), } return &dst.UnaryExpr{ diff --git a/runtime/sema/simple_type.go b/runtime/sema/simple_type.go index 049e0b29ed..cf411991e8 100644 --- a/runtime/sema/simple_type.go +++ b/runtime/sema/simple_type.go @@ -50,6 +50,7 @@ type SimpleType struct { Comparable bool Storable bool IsResource bool + MemberAccessible bool } var _ Type = &SimpleType{} @@ -106,6 +107,10 @@ func (t *SimpleType) IsImportable(_ map[*Member]bool) bool { return t.Importable } +func (t *SimpleType) IsMemberAccessible() bool { + return t.MemberAccessible +} + func (*SimpleType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } diff --git a/runtime/sema/storage_capability_controller.cdc b/runtime/sema/storage_capability_controller.cdc index 7d70961630..0f9ba83dba 100644 --- a/runtime/sema/storage_capability_controller.cdc +++ b/runtime/sema/storage_capability_controller.cdc @@ -1,4 +1,4 @@ -access(all) struct StorageCapabilityController { +access(all) struct StorageCapabilityController: MemberAccessible { /// An arbitrary "tag" for the controller. /// For example, it could be used to describe the purpose of the capability. diff --git a/runtime/sema/storage_capability_controller.gen.go b/runtime/sema/storage_capability_controller.gen.go index ac353f248e..ec12b21b75 100644 --- a/runtime/sema/storage_capability_controller.gen.go +++ b/runtime/sema/storage_capability_controller.gen.go @@ -123,16 +123,17 @@ The path may be different or the same as the current path. const StorageCapabilityControllerTypeName = "StorageCapabilityController" var StorageCapabilityControllerType = &SimpleType{ - Name: StorageCapabilityControllerTypeName, - QualifiedName: StorageCapabilityControllerTypeName, - TypeID: StorageCapabilityControllerTypeName, - tag: StorageCapabilityControllerTypeTag, - IsResource: false, - Storable: false, - Equatable: false, - Comparable: false, - Exportable: false, - Importable: false, + Name: StorageCapabilityControllerTypeName, + QualifiedName: StorageCapabilityControllerTypeName, + TypeID: StorageCapabilityControllerTypeName, + tag: StorageCapabilityControllerTypeTag, + IsResource: false, + Storable: false, + Equatable: false, + Comparable: false, + Exportable: false, + Importable: false, + MemberAccessible: true, } func init() { diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 73b53df7ba..097127f9c5 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -135,6 +135,10 @@ type Type interface { // IsComparable returns true if values of the type can be compared IsComparable() bool + // IsMemberAccessible returns true if value of the type can have members values. + // Examples are composite types, restricted types, arrays, dictionaries, etc. + IsMemberAccessible() bool + TypeAnnotationState() TypeAnnotationState RewriteWithRestrictedTypes() (result Type, rewritten bool) @@ -638,6 +642,10 @@ func (*OptionalType) IsComparable() bool { return false } +func (t *OptionalType) IsMemberAccessible() bool { + return t.Type.IsMemberAccessible() +} + func (t *OptionalType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -843,6 +851,10 @@ func (*GenericType) IsComparable() bool { return false } +func (t *GenericType) IsMemberAccessible() bool { + return false +} + func (*GenericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -1142,6 +1154,10 @@ func (t *NumericType) IsComparable() bool { return !t.IsSuperType() } +func (t *NumericType) IsMemberAccessible() bool { + return false +} + func (*NumericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -1342,6 +1358,10 @@ func (t *FixedPointNumericType) IsComparable() bool { return !t.IsSuperType() } +func (t *FixedPointNumericType) IsMemberAccessible() bool { + return false +} + func (*FixedPointNumericType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -2402,6 +2422,10 @@ func (t *VariableSizedType) IsComparable() bool { return t.Type.IsComparable() } +func (t *VariableSizedType) IsMemberAccessible() bool { + return true +} + func (t *VariableSizedType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -2552,6 +2576,10 @@ func (t *ConstantSizedType) IsComparable() bool { return t.Type.IsComparable() } +func (t *ConstantSizedType) IsMemberAccessible() bool { + return true +} + func (t *ConstantSizedType) TypeAnnotationState() TypeAnnotationState { return t.Type.TypeAnnotationState() } @@ -3073,6 +3101,10 @@ func (*FunctionType) IsComparable() bool { return false } +func (*FunctionType) IsMemberAccessible() bool { + return false +} + func (t *FunctionType) TypeAnnotationState() TypeAnnotationState { for _, typeParameter := range t.TypeParameters { @@ -4244,8 +4276,12 @@ func (*CompositeType) IsComparable() bool { return false } -func (c *CompositeType) TypeAnnotationState() TypeAnnotationState { - if c.Kind == common.CompositeKindAttachment { +func (*CompositeType) IsMemberAccessible() bool { + return true +} + +func (t *CompositeType) TypeAnnotationState() TypeAnnotationState { + if t.Kind == common.CompositeKindAttachment { return TypeAnnotationStateDirectAttachmentTypeAnnotation } return TypeAnnotationStateValid @@ -4928,6 +4964,10 @@ func (*InterfaceType) IsComparable() bool { return false } +func (*InterfaceType) IsMemberAccessible() bool { + return true +} + func (*InterfaceType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -5145,6 +5185,10 @@ func (*DictionaryType) IsComparable() bool { return false } +func (*DictionaryType) IsMemberAccessible() bool { + return true +} + func (t *DictionaryType) TypeAnnotationState() TypeAnnotationState { keyTypeAnnotationState := t.KeyType.TypeAnnotationState() if keyTypeAnnotationState != TypeAnnotationStateValid { @@ -5610,6 +5654,10 @@ func (*ReferenceType) IsComparable() bool { return false } +func (*ReferenceType) IsMemberAccessible() bool { + return false +} + func (r *ReferenceType) TypeAnnotationState() TypeAnnotationState { if r.Type.TypeAnnotationState() == TypeAnnotationStateDirectEntitlementTypeAnnotation { return TypeAnnotationStateDirectEntitlementTypeAnnotation @@ -5809,6 +5857,10 @@ func (*AddressType) IsComparable() bool { return false } +func (*AddressType) IsMemberAccessible() bool { + return false +} + func (*AddressType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -6526,6 +6578,10 @@ func (*TransactionType) IsComparable() bool { return false } +func (*TransactionType) IsMemberAccessible() bool { + return false +} + func (*TransactionType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -6763,6 +6819,10 @@ func (t *RestrictedType) IsComparable() bool { return false } +func (*RestrictedType) IsMemberAccessible() bool { + return true +} + func (*RestrictedType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateValid } @@ -7038,6 +7098,10 @@ func (*CapabilityType) IsComparable() bool { return false } +func (*CapabilityType) IsMemberAccessible() bool { + return false +} + func (t *CapabilityType) RewriteWithRestrictedTypes() (Type, bool) { if t.BorrowType == nil { return t, false @@ -7591,6 +7655,10 @@ func (*EntitlementType) IsResourceType() bool { return false } +func (*EntitlementType) IsMemberAccessible() bool { + return false +} + func (*EntitlementType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateDirectEntitlementTypeAnnotation } @@ -7721,6 +7789,10 @@ func (*EntitlementMapType) IsResourceType() bool { return false } +func (*EntitlementMapType) IsMemberAccessible() bool { + return false +} + func (*EntitlementMapType) TypeAnnotationState() TypeAnnotationState { return TypeAnnotationStateDirectEntitlementTypeAnnotation } diff --git a/runtime/tests/checker/member_test.go b/runtime/tests/checker/member_test.go index 289b6c7026..ac6d1d4040 100644 --- a/runtime/tests/checker/member_test.go +++ b/runtime/tests/checker/member_test.go @@ -19,6 +19,8 @@ package checker import ( + "fmt" + "github.com/onflow/cadence/runtime/interpreter" "testing" "github.com/stretchr/testify/assert" @@ -462,7 +464,7 @@ func TestCheckMemberAccess(t *testing.T) { require.NoError(t, err) }) - t.Run("composite reference, field", func(t *testing.T) { + t.Run("composite reference, array field", func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, ` @@ -713,4 +715,54 @@ func TestCheckMemberAccess(t *testing.T) { assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) + + t.Run("all member types", func(t *testing.T) { + t.Parallel() + + test := func(tt *testing.T, typeName string) { + code := fmt.Sprintf(` + struct Foo { + var a: %[1]s? + + init() { + self.a = nil + } + } + + struct Bar {} + + struct interface I {} + + fun test() { + let foo = Foo() + let fooRef = &foo as &Foo + var a: &%[1]s? = fooRef.a + }`, + + typeName, + ) + + _, err := ParseAndCheck(t, code) + require.NoError(t, err) + } + + types := []string{ + "Bar", + "{I}", + "AnyStruct", + "Block", + } + + // Test all built-in composite types + for i := interpreter.PrimitiveStaticTypeAuthAccount; i < interpreter.PrimitiveStaticType_Count; i++ { + semaType := i.SemaType() + types = append(types, semaType.QualifiedString()) + } + + for _, typeName := range types { + t.Run(typeName, func(t *testing.T) { + test(t, typeName) + }) + } + }) } diff --git a/runtime/tests/interpreter/member_test.go b/runtime/tests/interpreter/member_test.go index 68d6f491e6..87885d0741 100644 --- a/runtime/tests/interpreter/member_test.go +++ b/runtime/tests/interpreter/member_test.go @@ -19,6 +19,7 @@ package interpreter_test import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -1024,4 +1025,58 @@ func TestInterpretMemberAccess(t *testing.T) { _, err := inter.Invoke("test") require.NoError(t, err) }) + + t.Run("all member types", func(t *testing.T) { + t.Parallel() + + test := func(tt *testing.T, typeName string) { + code := fmt.Sprintf(` + struct Foo { + var a: %[1]s? + + init() { + self.a = nil + } + } + + struct Bar {} + + struct interface I {} + + fun test() { + let foo = Foo() + let fooRef = &foo as &Foo + var a: &%[1]s? = fooRef.a + }`, + + typeName, + ) + + inter := parseCheckAndInterpret(t, code) + + _, err := inter.Invoke("test") + require.NoError(t, err) + } + + types := []string{ + "Bar", + "{I}", + "[Int]", + "{Bool: String}", + "AnyStruct", + "Block", + } + + // Test all built-in composite types + for i := interpreter.PrimitiveStaticTypeAuthAccount; i < interpreter.PrimitiveStaticType_Count; i++ { + semaType := i.SemaType() + types = append(types, semaType.QualifiedString()) + } + + for _, typeName := range types { + t.Run(typeName, func(t *testing.T) { + test(t, typeName) + }) + } + }) }