From ca5044a9ce3f7db451c086fe6db094b195d5eee4 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Fri, 2 Jun 2023 13:26:31 -0400 Subject: [PATCH 01/10] convert capability values on upcast --- runtime/entitlements_test.go | 211 +++++++++++++++++- runtime/interpreter/interpreter.go | 24 +- .../tests/interpreter/entitlements_test.go | 15 +- 3 files changed, 241 insertions(+), 9 deletions(-) diff --git a/runtime/entitlements_test.go b/runtime/entitlements_test.go index eed3e23699..b81b720304 100644 --- a/runtime/entitlements_test.go +++ b/runtime/entitlements_test.go @@ -59,7 +59,7 @@ func TestAccountEntitlementSaveAndLoadSuccess(t *testing.T) { import Test from 0x1 transaction { prepare(signer: AuthAccount) { - let cap = signer.getCapability(/public/foo) + let cap = signer.getCapability(/public/foo) let ref = cap.borrow()! let downcastRef = ref as! auth(Test.X, Test.Y) &Int } @@ -501,3 +501,212 @@ func TestAccountEntitlementNamingConflict(t *testing.T) { var accessError *sema.InvalidAccessError require.ErrorAs(t, errs[0], &accessError) } + +func TestAccountEntitlementCapabilityCasting(t *testing.T) { + t.Parallel() + + storage := newTestLedger(nil, nil) + rt := newTestInterpreterRuntimeWithAttachments() + accountCodes := map[Location][]byte{} + + deployTx := DeploymentTransaction("Test", []byte(` + pub contract Test { + pub entitlement X + pub entitlement Y + + pub resource R {} + + pub fun createR(): @R { + return <-create R() + } + } + `)) + + transaction1 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let r <- Test.createR() + signer.save(<-r, to: /storage/foo) + signer.link(/public/foo, target: /storage/foo) + } + } + `) + + transaction2 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let capX = signer.getCapability(/public/foo) + let upCap = capX as Capability<&Test.R> + let downCap = upCap as! Capability + } + } + `) + + runtimeInterface1 := &testRuntimeInterface{ + storage: storage, + log: func(message string) {}, + emitEvent: func(event cadence.Event) error { + return nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getSigningAccounts: func() ([]Address, error) { + return []Address{[8]byte{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction1, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction2, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + + require.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) +} + +func TestAccountEntitlementCapabilityDictionary(t *testing.T) { + t.Parallel() + + storage := newTestLedger(nil, nil) + rt := newTestInterpreterRuntimeWithAttachments() + accountCodes := map[Location][]byte{} + + deployTx := DeploymentTransaction("Test", []byte(` + pub contract Test { + pub entitlement X + pub entitlement Y + + pub resource R {} + + pub fun createR(): @R { + return <-create R() + } + } + `)) + + transaction1 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let r <- Test.createR() + signer.save(<-r, to: /storage/foo) + signer.link(/public/foo, target: /storage/foo) + + let r2 <- Test.createR() + signer.save(<-r2, to: /storage/bar) + signer.link(/public/bar, target: /storage/bar) + } + } + `) + + transaction2 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let capX = signer.getCapability(/public/foo) + let capY = signer.getCapability(/public/bar) + + let dict: {Type: Capability<&Test.R>} = {} + dict[capX.getType()] = capX + dict[capY.getType()] = capY + + let newCapX = dict[capX.getType()]! + let ref = newCapX.borrow()! + let downCast = ref as! auth(Test.X) &Test.R + } + } + `) + + runtimeInterface1 := &testRuntimeInterface{ + storage: storage, + log: func(message string) {}, + emitEvent: func(event cadence.Event) error { + return nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getSigningAccounts: func() ([]Address, error) { + return []Address{[8]byte{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction1, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction2, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + + require.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) +} diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 3e81c1315c..949d97a121 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1878,6 +1878,25 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. return ConvertAddress(interpreter, value, locationRange) } + case *sema.CapabilityType: + if !valueType.Equal(unwrappedTargetType) { + if capability, ok := value.(*StorageCapabilityValue); ok && unwrappedTargetType.BorrowType != nil { + targetBorrowType := unwrappedTargetType.BorrowType.(*sema.ReferenceType) + valueBorrowType := capability.BorrowType.(ReferenceStaticType) + borrowType := NewReferenceStaticType( + interpreter, + ConvertSemaAccesstoStaticAuthorization(interpreter, targetBorrowType.Authorization), + valueBorrowType.ReferencedType, + ) + return NewStorageCapabilityValue( + interpreter, + capability.Address, + capability.Path, + borrowType, + ) + } + } + case *sema.ReferenceType: if !valueType.Equal(unwrappedTargetType) { // transferring a reference at runtime does not change its entitlements; this is so that an upcast reference @@ -4222,7 +4241,6 @@ func (interpreter *Interpreter) GetStorageCapabilityFinalTarget( authorization Authorization, err error, ) { - wantedReferenceType := wantedBorrowType seenPaths := map[PathValue]struct{}{} paths := []PathValue{path} @@ -4258,8 +4276,6 @@ func (interpreter *Interpreter) GetStorageCapabilityFinalTarget( return nil, UnauthorizedAccess, nil } - wantedReferenceType = allowedType.(*sema.ReferenceType) - targetPath := value.TargetPath paths = append(paths, targetPath) path = targetPath @@ -4278,7 +4294,7 @@ func (interpreter *Interpreter) GetStorageCapabilityFinalTarget( default: return PathCapabilityTarget(path), - ConvertSemaAccesstoStaticAuthorization(interpreter, wantedReferenceType.Authorization), + ConvertSemaAccesstoStaticAuthorization(interpreter, wantedBorrowType.Authorization), nil } } diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index fd51cf8837..bfbd68738f 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -577,7 +577,7 @@ func TestInterpretCapabilityEntitlements(t *testing.T) { require.NoError(t, err) }) - t.Run("can borrow with supertype then downcast", func(t *testing.T) { + t.Run("cannot borrow with supertype then downcast", func(t *testing.T) { t.Parallel() address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) @@ -589,19 +589,26 @@ func TestInterpretCapabilityEntitlements(t *testing.T) { entitlement X entitlement Y resource R {} - fun test(): &R { + fun test(): Bool { let r <- create R() account.save(<-r, to: /storage/foo) account.link(/public/foo, target: /storage/foo) let cap = account.getCapability(/public/foo) - return cap.borrow()! as! auth(X, Y) &R + return cap.borrow()! as? auth(X, Y) &R != nil } `, sema.Config{}, ) - _, err := inter.Invoke("test") + value, err := inter.Invoke("test") require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.FalseValue, + value, + ) }) t.Run("can check with supertype", func(t *testing.T) { From cab0aa869818e592a2befb7e9856e18c06339a93 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Fri, 2 Jun 2023 13:30:42 -0400 Subject: [PATCH 02/10] add test for generic capabilities --- runtime/entitlements_test.go | 110 +++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/runtime/entitlements_test.go b/runtime/entitlements_test.go index b81b720304..eed820f96f 100644 --- a/runtime/entitlements_test.go +++ b/runtime/entitlements_test.go @@ -710,3 +710,113 @@ func TestAccountEntitlementCapabilityDictionary(t *testing.T) { require.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) } + +func TestAccountEntitlementGenericCapabilityDictionary(t *testing.T) { + t.Parallel() + + storage := newTestLedger(nil, nil) + rt := newTestInterpreterRuntimeWithAttachments() + accountCodes := map[Location][]byte{} + + deployTx := DeploymentTransaction("Test", []byte(` + pub contract Test { + pub entitlement X + pub entitlement Y + + pub resource R {} + + pub fun createR(): @R { + return <-create R() + } + } + `)) + + transaction1 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let r <- Test.createR() + signer.save(<-r, to: /storage/foo) + signer.link(/public/foo, target: /storage/foo) + + let r2 <- Test.createR() + signer.save(<-r2, to: /storage/bar) + signer.link(/public/bar, target: /storage/bar) + } + } + `) + + transaction2 := []byte(` + import Test from 0x1 + transaction { + prepare(signer: AuthAccount) { + let capX = signer.getCapability(/public/foo) + let capY = signer.getCapability(/public/bar) + + let dict: {Type: Capability} = {} + dict[capX.getType()] = capX + dict[capY.getType()] = capY + + let newCapX = dict[capX.getType()]! + let ref = newCapX.borrow()! + let downCast = ref as! auth(Test.X) &Test.R + } + } + `) + + runtimeInterface1 := &testRuntimeInterface{ + storage: storage, + log: func(message string) {}, + emitEvent: func(event cadence.Event) error { + return nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getSigningAccounts: func() ([]Address, error) { + return []Address{[8]byte{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction1, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + err = rt.ExecuteTransaction( + Script{ + Source: transaction2, + }, + Context{ + Interface: runtimeInterface1, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) +} From d054594d367e132cbb57d2b7d8bd6792241adf3c Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Fri, 2 Jun 2023 15:29:57 -0400 Subject: [PATCH 03/10] add runtime type test --- .../tests/interpreter/entitlements_test.go | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index bfbd68738f..1a9892aba9 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -611,6 +611,42 @@ func TestInterpretCapabilityEntitlements(t *testing.T) { ) }) + t.Run("upcast runtime type", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + struct S {} + fun test(): Bool { + let s = S() + account.save(s, to: /storage/foo) + account.link(/public/foo, target: /storage/foo) + let cap: Capability = account.getCapability(/public/foo) + let runtimeType = cap.getType() + let upcastCap = cap as Capability<&S> + let upcastRuntimeType = upcastCap.getType() + return runtimeType == upcastRuntimeType + } + `, + sema.Config{}, + ) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.FalseValue, + value, + ) + }) + t.Run("can check with supertype", func(t *testing.T) { t.Parallel() From 3460b2b1c36b38bf3d4b74bcdf9260a203edafd4 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Fri, 2 Jun 2023 15:39:18 -0400 Subject: [PATCH 04/10] test for upcasting runtime type --- .../tests/interpreter/entitlements_test.go | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 1a9892aba9..cdf8da431b 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -611,7 +611,7 @@ func TestInterpretCapabilityEntitlements(t *testing.T) { ) }) - t.Run("upcast runtime type", func(t *testing.T) { + t.Run("upcast runtime entitlements", func(t *testing.T) { t.Parallel() address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) @@ -647,6 +647,41 @@ func TestInterpretCapabilityEntitlements(t *testing.T) { ) }) + t.Run("upcast runtime type", func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + struct S {} + fun test(): Bool { + let s = S() + account.save(s, to: /storage/foo) + account.link<&S>(/public/foo, target: /storage/foo) + let cap: Capability<&S> = account.getCapability<&S>(/public/foo) + let runtimeType = cap.getType() + let upcastCap = cap as Capability<&AnyStruct> + let upcastRuntimeType = upcastCap.getType() + return runtimeType == upcastRuntimeType + } + `, + sema.Config{}, + ) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + t.Run("can check with supertype", func(t *testing.T) { t.Parallel() From 4a9d6464d6880e61101af1e9e566ce94ce0c0560 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Mon, 12 Jun 2023 15:08:06 -0400 Subject: [PATCH 05/10] convert optional values --- runtime/interpreter/interpreter.go | 35 ++- .../tests/interpreter/entitlements_test.go | 252 ++++++++++++++++++ .../tests/interpreter/memory_metering_test.go | 4 +- 3 files changed, 282 insertions(+), 9 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 5b3ce7d922..91b71ea6ed 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1781,12 +1781,18 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. return value } - if _, valueIsOptional := valueType.(*sema.OptionalType); valueIsOptional { - return value - } - unwrappedTargetType := sema.UnwrapOptionalType(targetType) + if optionalValueType, valueIsOptional := valueType.(*sema.OptionalType); valueIsOptional { + switch value := value.(type) { + case NilValue: + return value + case *SomeValue: + innerValue := interpreter.convert(value.value, optionalValueType.Type, unwrappedTargetType, locationRange) + return NewSomeValueNonCopying(interpreter, innerValue) + } + } + switch unwrappedTargetType { case sema.IntType: if !valueType.Equal(unwrappedTargetType) { @@ -1906,9 +1912,11 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. } case *sema.CapabilityType: - if !valueType.Equal(unwrappedTargetType) { - if capability, ok := value.(*PathCapabilityValue); ok && unwrappedTargetType.BorrowType != nil { - targetBorrowType := unwrappedTargetType.BorrowType.(*sema.ReferenceType) + if !valueType.Equal(unwrappedTargetType) && unwrappedTargetType.BorrowType != nil { + targetBorrowType := unwrappedTargetType.BorrowType.(*sema.ReferenceType) + + switch capability := value.(type) { + case *PathCapabilityValue: valueBorrowType := capability.BorrowType.(ReferenceStaticType) borrowType := NewReferenceStaticType( interpreter, @@ -1921,6 +1929,19 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. capability.Path, borrowType, ) + case *IDCapabilityValue: + valueBorrowType := capability.BorrowType.(ReferenceStaticType) + borrowType := NewReferenceStaticType( + interpreter, + ConvertSemaAccesstoStaticAuthorization(interpreter, targetBorrowType.Authorization), + valueBorrowType.ReferencedType, + ) + return NewIDCapabilityValue( + interpreter, + capability.ID, + capability.Address, + borrowType, + ) } } diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 9055b7ed41..5f94f52d30 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -577,6 +577,258 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { value, ) }) + + t.Run("capability downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + entitlement Y + + fun test(): Bool { + account.save(3, to: /storage/foo) + let capX = account.getCapability(/public/foo) + let upCap = capX as Capability + return upCap as? Capability == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("unparameterized capability downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + account.save(3, to: /storage/foo) + let capX = account.getCapability(/public/foo) + let upCap = capX as Capability + return upCap as? Capability == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.FalseValue, + value, + ) + }) + + t.Run("ref downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: auth(X) &Int = &1 + let upArr = arr as &Int + return upArr as? auth(X) &Int == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("optional ref downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: auth(X) &Int? = &1 + let upArr = arr as &Int? + return upArr as? auth(X) &Int? == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("ref array downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: [auth(X) &Int] = [&1, &2] + let upArr = arr as [&Int] + return upArr as? [auth(X) &Int] == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("ref array element downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: [auth(X) &Int] = [&1, &2] + let upArr = arr as [&Int] + return upArr[0] as? auth(X) &Int == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("ref dict downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let dict: {String: auth(X) &Int} = {"foo": &3} + let upDict = dict as {String: &Int} + return upDict as? {String: auth(X) &Int} == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + + t.Run("ref dict element downcast forced", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let dict: {String: auth(X) &Int} = {"foo": &3} + let upDict = dict as {String: &Int} + return upDict["foo"]! as? auth(X) &Int == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + } func TestInterpretCapabilityEntitlements(t *testing.T) { diff --git a/runtime/tests/interpreter/memory_metering_test.go b/runtime/tests/interpreter/memory_metering_test.go index 1dcb7959a1..584eb27964 100644 --- a/runtime/tests/interpreter/memory_metering_test.go +++ b/runtime/tests/interpreter/memory_metering_test.go @@ -1296,8 +1296,8 @@ func TestInterpretOptionalValueMetering(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) - // 2 for `z` - assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindOptionalValue)) + // 3 for `z` + assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindOptionalValue)) assert.Equal(t, uint64(14), meter.getMemory(common.MemoryKindPrimitiveStaticType)) assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindDictionaryStaticType)) From 72c89fe971b9269a136628541b56361f17f5fe2b Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Mon, 12 Jun 2023 16:51:34 -0400 Subject: [PATCH 06/10] convert array and dictionary values --- runtime/interpreter/interpreter.go | 145 ++++++++++++++++-- .../tests/interpreter/entitlements_test.go | 62 ++++++++ .../tests/interpreter/memory_metering_test.go | 4 +- 3 files changed, 197 insertions(+), 14 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 91b71ea6ed..86f54ef350 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1776,6 +1776,79 @@ func (interpreter *Interpreter) ConvertAndBox( return interpreter.BoxOptional(locationRange, value, targetType) } +func (interpreter *Interpreter) convertStaticType( + valueStaticType StaticType, + targetSemaType sema.Type, +) StaticType { + switch valueStaticType := valueStaticType.(type) { + case ReferenceStaticType: + if targetReferenceType, isReferenceType := targetSemaType.(*sema.ReferenceType); isReferenceType { + return NewReferenceStaticType( + interpreter, + ConvertSemaAccesstoStaticAuthorization(interpreter, targetReferenceType.Authorization), + valueStaticType.ReferencedType, + ) + } + case OptionalStaticType: + if targetOptionalType, isOptionalType := targetSemaType.(*sema.OptionalType); isOptionalType { + return NewOptionalStaticType( + interpreter, + interpreter.convertStaticType( + valueStaticType.Type, + targetOptionalType.Type, + ), + ) + } + case DictionaryStaticType: + if targetDictionaryType, isDictionaryType := targetSemaType.(*sema.DictionaryType); isDictionaryType { + return NewDictionaryStaticType( + interpreter, + interpreter.convertStaticType( + valueStaticType.KeyType, + targetDictionaryType.KeyType, + ), + interpreter.convertStaticType( + valueStaticType.ValueType, + targetDictionaryType.ValueType, + ), + ) + } + case VariableSizedStaticType: + if targetArrayType, isArrayType := targetSemaType.(*sema.VariableSizedType); isArrayType { + return NewVariableSizedStaticType( + interpreter, + interpreter.convertStaticType( + valueStaticType.Type, + targetArrayType.Type, + ), + ) + } + case ConstantSizedStaticType: + if targetArrayType, isArrayType := targetSemaType.(*sema.ConstantSizedType); isArrayType { + return NewConstantSizedStaticType( + interpreter, + interpreter.convertStaticType( + valueStaticType.Type, + targetArrayType.Type, + ), + valueStaticType.Size, + ) + } + + case CapabilityStaticType: + if targetCapabilityType, isCapabilityType := targetSemaType.(*sema.CapabilityType); isCapabilityType { + return NewCapabilityStaticType( + interpreter, + interpreter.convertStaticType( + valueStaticType.BorrowType, + targetCapabilityType.BorrowType, + ), + ) + } + } + return valueStaticType +} + func (interpreter *Interpreter) convert(value Value, valueType, targetType sema.Type, locationRange LocationRange) Value { if valueType == nil { return value @@ -1783,13 +1856,17 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. unwrappedTargetType := sema.UnwrapOptionalType(targetType) + // if the value is optional, convert the inner value to the unwrapped target type if optionalValueType, valueIsOptional := valueType.(*sema.OptionalType); valueIsOptional { switch value := value.(type) { case NilValue: return value case *SomeValue: - innerValue := interpreter.convert(value.value, optionalValueType.Type, unwrappedTargetType, locationRange) - return NewSomeValueNonCopying(interpreter, innerValue) + if !optionalValueType.Type.Equal(unwrappedTargetType) { + innerValue := interpreter.convert(value.value, optionalValueType.Type, unwrappedTargetType, locationRange) + return NewSomeValueNonCopying(interpreter, innerValue) + } + return value } } @@ -1911,6 +1988,57 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. return ConvertAddress(interpreter, value, locationRange) } + case *sema.ConstantSizedType, *sema.VariableSizedType: + if arrayValue, isArray := value.(*ArrayValue); isArray && !valueType.Equal(unwrappedTargetType) { + + arrayStaticType := interpreter.convertStaticType(arrayValue.StaticType(interpreter), unwrappedTargetType).(ArrayStaticType) + targetElementType := interpreter.MustConvertStaticToSemaType(arrayStaticType.ElementType()) + + values := make([]Value, 0, arrayValue.Count()) + + arrayValue.Iterate(interpreter, func(v Value) bool { + valueType := interpreter.MustConvertStaticToSemaType(v.StaticType(interpreter)) + values = append(values, interpreter.convert(v, valueType, targetElementType, locationRange)) + return true + }) + + return NewArrayValue( + interpreter, + locationRange, + arrayStaticType, + arrayValue.GetOwner(), + values..., + ) + } + + case *sema.DictionaryType: + if dictValue, isDict := value.(*DictionaryValue); isDict && !valueType.Equal(unwrappedTargetType) { + + dictStaticType := interpreter.convertStaticType(dictValue.StaticType(interpreter), unwrappedTargetType).(DictionaryStaticType) + targetKeyType := interpreter.MustConvertStaticToSemaType(dictStaticType.KeyType) + targetValueType := interpreter.MustConvertStaticToSemaType(dictStaticType.ValueType) + + values := make([]Value, 0, dictValue.Count()*2) + + dictValue.Iterate(interpreter, func(key, value Value) bool { + keyType := interpreter.MustConvertStaticToSemaType(key.StaticType(interpreter)) + valueType := interpreter.MustConvertStaticToSemaType(value.StaticType(interpreter)) + + convertedKey := interpreter.convert(key, keyType, targetKeyType, locationRange) + convertedValue := interpreter.convert(value, valueType, targetValueType, locationRange) + + values = append(values, convertedKey, convertedValue) + return true + }) + + return NewDictionaryValue( + interpreter, + locationRange, + dictStaticType, + values..., + ) + } + case *sema.CapabilityType: if !valueType.Equal(unwrappedTargetType) && unwrappedTargetType.BorrowType != nil { targetBorrowType := unwrappedTargetType.BorrowType.(*sema.ReferenceType) @@ -1918,11 +2046,7 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. switch capability := value.(type) { case *PathCapabilityValue: valueBorrowType := capability.BorrowType.(ReferenceStaticType) - borrowType := NewReferenceStaticType( - interpreter, - ConvertSemaAccesstoStaticAuthorization(interpreter, targetBorrowType.Authorization), - valueBorrowType.ReferencedType, - ) + borrowType := interpreter.convertStaticType(valueBorrowType, targetBorrowType) return NewPathCapabilityValue( interpreter, capability.Address, @@ -1931,11 +2055,7 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. ) case *IDCapabilityValue: valueBorrowType := capability.BorrowType.(ReferenceStaticType) - borrowType := NewReferenceStaticType( - interpreter, - ConvertSemaAccesstoStaticAuthorization(interpreter, targetBorrowType.Authorization), - valueBorrowType.ReferencedType, - ) + borrowType := interpreter.convertStaticType(valueBorrowType, targetBorrowType) return NewIDCapabilityValue( interpreter, capability.ID, @@ -1949,6 +2069,7 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. if !valueType.Equal(unwrappedTargetType) { // transferring a reference at runtime does not change its entitlements; this is so that an upcast reference // can later be downcast back to its original entitlement set + switch ref := value.(type) { case *EphemeralReferenceValue: return NewEphemeralReferenceValue( diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 5f94f52d30..9d81158e06 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -736,6 +736,37 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { ) }) + t.Run("ref constant array downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: [auth(X) ∬ 2] = [&1, &2] + let upArr = arr as [∬ 2] + return upArr as? [auth(X) ∬ 2] == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + t.Run("ref array element downcast", func(t *testing.T) { t.Parallel() @@ -767,6 +798,37 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { ) }) + t.Run("ref constant array element downcast", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: [auth(X) ∬ 2] = [&1, &2] + let upArr = arr as [∬ 2] + return upArr[0] as? auth(X) &Int == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.TrueValue, + value, + ) + }) + t.Run("ref dict downcast", func(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/memory_metering_test.go b/runtime/tests/interpreter/memory_metering_test.go index 584eb27964..1dcb7959a1 100644 --- a/runtime/tests/interpreter/memory_metering_test.go +++ b/runtime/tests/interpreter/memory_metering_test.go @@ -1296,8 +1296,8 @@ func TestInterpretOptionalValueMetering(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) - // 3 for `z` - assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindOptionalValue)) + // 2 for `z` + assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindOptionalValue)) assert.Equal(t, uint64(14), meter.getMemory(common.MemoryKindPrimitiveStaticType)) assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindDictionaryStaticType)) From dca96653c172f2322cd8c499f50181eb71a2d602 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 20 Jun 2023 12:05:13 -0400 Subject: [PATCH 07/10] respond to review --- runtime/interpreter/interpreter.go | 75 +++++++++++++------ runtime/interpreter/value.go | 62 +++++++++++++++ .../tests/interpreter/entitlements_test.go | 15 +++- 3 files changed, 127 insertions(+), 25 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 86f54ef350..e2485770b8 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1994,20 +1994,31 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. arrayStaticType := interpreter.convertStaticType(arrayValue.StaticType(interpreter), unwrappedTargetType).(ArrayStaticType) targetElementType := interpreter.MustConvertStaticToSemaType(arrayStaticType.ElementType()) - values := make([]Value, 0, arrayValue.Count()) + array := arrayValue.array - arrayValue.Iterate(interpreter, func(v Value) bool { - valueType := interpreter.MustConvertStaticToSemaType(v.StaticType(interpreter)) - values = append(values, interpreter.convert(v, valueType, targetElementType, locationRange)) - return true - }) + iterator, err := array.Iterator() + if err != nil { + panic(errors.NewExternalError(err)) + } - return NewArrayValue( + return NewArrayValueWithIterator( interpreter, - locationRange, arrayStaticType, arrayValue.GetOwner(), - values..., + array.Count(), + func() Value { + element, err := iterator.Next() + if err != nil { + panic(errors.NewExternalError(err)) + } + if element == nil { + return nil + } + + value := MustConvertStoredValue(interpreter, element) + valueType := interpreter.MustConvertStaticToSemaType(value.StaticType(interpreter)) + return interpreter.convert(value, valueType, targetElementType, locationRange) + }, ) } @@ -2018,24 +2029,41 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. targetKeyType := interpreter.MustConvertStaticToSemaType(dictStaticType.KeyType) targetValueType := interpreter.MustConvertStaticToSemaType(dictStaticType.ValueType) - values := make([]Value, 0, dictValue.Count()*2) - - dictValue.Iterate(interpreter, func(key, value Value) bool { - keyType := interpreter.MustConvertStaticToSemaType(key.StaticType(interpreter)) - valueType := interpreter.MustConvertStaticToSemaType(value.StaticType(interpreter)) + dictionary := dictValue.dictionary - convertedKey := interpreter.convert(key, keyType, targetKeyType, locationRange) - convertedValue := interpreter.convert(value, valueType, targetValueType, locationRange) - - values = append(values, convertedKey, convertedValue) - return true - }) + iterator, err := dictionary.Iterator() + if err != nil { + panic(errors.NewExternalError(err)) + } - return NewDictionaryValue( + return newDictionaryValueWithIterator( interpreter, locationRange, dictStaticType, - values..., + dictionary.Count(), + dictionary.Seed(), + common.Address(dictionary.Address()), + func() (Value, Value) { + k, v, err := iterator.Next() + + if err != nil { + panic(errors.NewExternalError(err)) + } + if k == nil || v == nil { + return nil, nil + } + + key := MustConvertStoredValue(interpreter, k) + value := MustConvertStoredValue(interpreter, v) + + keyType := interpreter.MustConvertStaticToSemaType(key.StaticType(interpreter)) + valueType := interpreter.MustConvertStaticToSemaType(value.StaticType(interpreter)) + + convertedKey := interpreter.convert(key, keyType, targetKeyType, locationRange) + convertedValue := interpreter.convert(value, valueType, targetValueType, locationRange) + + return convertedKey, convertedValue + }, ) } @@ -2062,6 +2090,9 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. capability.Address, borrowType, ) + default: + // unsupported capability value + panic(errors.NewUnreachableError()) } } diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 311c81efaf..c50eb8bbb3 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -16191,6 +16191,68 @@ func newDictionaryValueFromOrderedMap( } } +func newDictionaryValueWithIterator( + interpreter *Interpreter, + locationRange LocationRange, + staticType DictionaryStaticType, + count uint64, + seed uint64, + address common.Address, + values func() (Value, Value), +) *DictionaryValue { + interpreter.ReportComputation(common.ComputationKindCreateDictionaryValue, 1) + + var v *DictionaryValue + + config := interpreter.SharedState.Config + + if config.TracingEnabled { + startTime := time.Now() + + defer func() { + // NOTE: in defer, as v is only initialized at the end of the function + // if there was no error during construction + if v == nil { + return + } + + typeInfo := v.Type.String() + count := v.Count() + + interpreter.reportDictionaryValueConstructTrace( + typeInfo, + count, + time.Since(startTime), + ) + }() + } + + constructor := func() *atree.OrderedMap { + orderedMap, err := atree.NewMapFromBatchData( + config.Storage, + atree.Address(address), + atree.NewDefaultDigesterBuilder(), + staticType, + newValueComparator(interpreter, locationRange), + newHashInputProvider(interpreter, locationRange), + seed, + func() (atree.Value, atree.Value, error) { + key, value := values() + return key, value, nil + }, + ) + if err != nil { + panic(errors.NewExternalError(err)) + } + return orderedMap + } + + // values are added to the dictionary after creation, not here + v = newDictionaryValueFromConstructor(interpreter, staticType, count, constructor) + + return v +} + func newDictionaryValueFromConstructor( gauge common.MemoryGauge, staticType DictionaryStaticType, diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index 9d81158e06..c881dcad97 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -623,11 +623,11 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { ` entitlement X - fun test(): Bool { + fun test(): Capability { account.save(3, to: /storage/foo) let capX = account.getCapability(/public/foo) let upCap = capX as Capability - return upCap as? Capability == nil + return (upCap as? Capability)! } `, sema.Config{}) @@ -638,7 +638,16 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { AssertValuesEqual( t, inter, - interpreter.FalseValue, + interpreter.NewPathCapabilityValue( + nil, + address, + interpreter.NewPathValue(nil, common.PathDomainPublic, "foo"), + interpreter.NewReferenceStaticType( + nil, + interpreter.NewEntitlementSetAuthorization(nil, []common.TypeID{"S.test.X"}, sema.Conjunction), + interpreter.PrimitiveStaticTypeInt, + ), + ), value, ) }) From 07aece407e27ad5ab7713e4f28a7b6bd7a6896fd Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 20 Jun 2023 14:20:19 -0400 Subject: [PATCH 08/10] add comment --- runtime/interpreter/interpreter.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index e2485770b8..49b7457ea4 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1776,6 +1776,10 @@ func (interpreter *Interpreter) ConvertAndBox( return interpreter.BoxOptional(locationRange, value, targetType) } +// Produces the `valueStaticType` argument into a new static type that conforms +// to the specification of the `targetSemaType`. At the moment, this means that the +// authorization of any reference types in `valueStaticType` are changed to match the +// authorization of any equivalently-positioned reference types in `targetSemaType`. func (interpreter *Interpreter) convertStaticType( valueStaticType StaticType, targetSemaType sema.Type, From 1092e7f7efb230ac86e0c4cac3f6a0c75f64a8f5 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 21 Jun 2023 16:44:43 -0400 Subject: [PATCH 09/10] Update runtime/interpreter/interpreter.go Co-authored-by: Supun Setunga --- runtime/interpreter/interpreter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 49b7457ea4..d153d3a416 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1992,7 +1992,7 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. return ConvertAddress(interpreter, value, locationRange) } - case *sema.ConstantSizedType, *sema.VariableSizedType: + case sema.ArrayType: if arrayValue, isArray := value.(*ArrayValue); isArray && !valueType.Equal(unwrappedTargetType) { arrayStaticType := interpreter.convertStaticType(arrayValue.StaticType(interpreter), unwrappedTargetType).(ArrayStaticType) From 9bae858d79ee52359fd88c7ce037584f78f532bd Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 21 Jun 2023 16:49:42 -0400 Subject: [PATCH 10/10] only produce new array and dictionary values when references within have changed --- runtime/interpreter/interpreter.go | 16 ++++++++-- .../tests/interpreter/entitlements_test.go | 31 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 49b7457ea4..e3b4dcaa1a 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -1995,7 +1995,13 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. case *sema.ConstantSizedType, *sema.VariableSizedType: if arrayValue, isArray := value.(*ArrayValue); isArray && !valueType.Equal(unwrappedTargetType) { - arrayStaticType := interpreter.convertStaticType(arrayValue.StaticType(interpreter), unwrappedTargetType).(ArrayStaticType) + oldArrayStaticType := arrayValue.StaticType(interpreter) + arrayStaticType := interpreter.convertStaticType(oldArrayStaticType, unwrappedTargetType).(ArrayStaticType) + + if oldArrayStaticType.Equal(arrayStaticType) { + return value + } + targetElementType := interpreter.MustConvertStaticToSemaType(arrayStaticType.ElementType()) array := arrayValue.array @@ -2029,7 +2035,13 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. case *sema.DictionaryType: if dictValue, isDict := value.(*DictionaryValue); isDict && !valueType.Equal(unwrappedTargetType) { - dictStaticType := interpreter.convertStaticType(dictValue.StaticType(interpreter), unwrappedTargetType).(DictionaryStaticType) + oldDictStaticType := dictValue.StaticType(interpreter) + dictStaticType := interpreter.convertStaticType(oldDictStaticType, unwrappedTargetType).(DictionaryStaticType) + + if oldDictStaticType.Equal(dictStaticType) { + return value + } + targetKeyType := interpreter.MustConvertStaticToSemaType(dictStaticType.KeyType) targetValueType := interpreter.MustConvertStaticToSemaType(dictStaticType.ValueType) diff --git a/runtime/tests/interpreter/entitlements_test.go b/runtime/tests/interpreter/entitlements_test.go index c881dcad97..313fdfd91b 100644 --- a/runtime/tests/interpreter/entitlements_test.go +++ b/runtime/tests/interpreter/entitlements_test.go @@ -776,6 +776,37 @@ func TestInterpretEntitledReferenceCasting(t *testing.T) { ) }) + t.Run("ref constant array downcast no change", func(t *testing.T) { + + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + inter, _ := testAccount(t, + address, + true, + ` + entitlement X + + fun test(): Bool { + let arr: [auth(X) ∬ 2] = [&1, &2] + let upArr = arr as [auth(X) ∬ 2] + return upArr as? [auth(X) ∬ 2] == nil + } + `, + sema.Config{}) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.FalseValue, + value, + ) + }) + t.Run("ref array element downcast", func(t *testing.T) { t.Parallel()