Skip to content

Commit

Permalink
Merge pull request #3266 from onflow/bastian/improve-supported-entitl…
Browse files Browse the repository at this point in the history
…ements
  • Loading branch information
turbolent authored Apr 24, 2024
2 parents 1ea9852 + 8f666bb commit 93a063a
Show file tree
Hide file tree
Showing 15 changed files with 790 additions and 101 deletions.
2 changes: 1 addition & 1 deletion migrations/entitlements/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (m EntitlementsMigration) ConvertToEntitledType(

default:
supportedEntitlements := entitlementSupportingType.SupportedEntitlements()
newAccess := sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction)
newAccess := supportedEntitlements.Access()
auth = interpreter.ConvertSemaAccessToStaticAuthorization(inter, newAccess)
returnNew = true
}
Expand Down
2 changes: 1 addition & 1 deletion migrations/entitlements/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ func TestConvertToEntitledType(t *testing.T) {
},
{
Input: sema.NewReferenceType(nil, sema.UnauthorizedAccess, compositeResourceWithEOrF),
Output: sema.NewReferenceType(nil, eAndFAccess, compositeResourceWithEOrF),
Output: sema.NewReferenceType(nil, eOrFAccess, compositeResourceWithEOrF),
Name: "composite E or F",
},
{
Expand Down
18 changes: 5 additions & 13 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,14 +817,8 @@ func (interpreter *Interpreter) resultValue(returnValue Value, returnType sema.T
auth := UnauthorizedAccess
// reference is authorized to the entire resource, since it is only accessible in a function where a resource value is owned
if entitlementSupportingType, ok := ty.(sema.EntitlementSupportingType); ok {
supportedEntitlements := entitlementSupportingType.SupportedEntitlements()
if supportedEntitlements != nil && supportedEntitlements.Len() > 0 {
access := sema.EntitlementSetAccess{
SetKind: sema.Conjunction,
Entitlements: supportedEntitlements,
}
auth = ConvertSemaAccessToStaticAuthorization(interpreter, access)
}
access := entitlementSupportingType.SupportedEntitlements().Access()
auth = ConvertSemaAccessToStaticAuthorization(interpreter, access)
}
return auth
}
Expand Down Expand Up @@ -1038,7 +1032,7 @@ func (interpreter *Interpreter) evaluateDefaultDestroyEvent(
panic(errors.NewUnreachableError())
}
supportedEntitlements := entitlementSupportingType.SupportedEntitlements()
access := sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction)
access := supportedEntitlements.Access()
base, self = attachmentBaseAndSelfValues(
declarationInterpreter,
access,
Expand Down Expand Up @@ -1393,10 +1387,8 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue(
// Self's type in the constructor is fully entitled, since
// the constructor can only be called when in possession of the base resource

auth := ConvertSemaAccessToStaticAuthorization(
interpreter,
sema.NewAccessFromEntitlementSet(attachmentType.SupportedEntitlements(), sema.Conjunction),
)
access := attachmentType.SupportedEntitlements().Access()
auth := ConvertSemaAccessToStaticAuthorization(interpreter, access)

self = NewEphemeralReferenceValue(interpreter, auth, value, attachmentType, locationRange)

Expand Down
2 changes: 1 addition & 1 deletion runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ func (interpreter *Interpreter) VisitAttachExpression(attachExpression *ast.Atta
// within the constructor, the attachment's base and self references should be fully entitled,
// as the constructor of the attachment is only callable by the owner of the base
baseType := interpreter.MustSemaTypeOfValue(base).(sema.EntitlementSupportingType)
baseAccess := sema.NewAccessFromEntitlementSet(baseType.SupportedEntitlements(), sema.Conjunction)
baseAccess := baseType.SupportedEntitlements().Access()
auth := ConvertSemaAccessToStaticAuthorization(interpreter, baseAccess)

attachmentType := interpreter.Program.Elaboration.AttachTypes(attachExpression)
Expand Down
4 changes: 2 additions & 2 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -18370,10 +18370,10 @@ func (v *CompositeValue) GetTypeKey(
locationRange LocationRange,
ty sema.Type,
) Value {
var access sema.Access = sema.UnauthorizedAccess
access := sema.UnauthorizedAccess
attachmentTyp, isAttachmentType := ty.(*sema.CompositeType)
if isAttachmentType {
access = sema.NewAccessFromEntitlementSet(attachmentTyp.SupportedEntitlements(), sema.Conjunction)
access = attachmentTyp.SupportedEntitlements().Access()
}
return v.getTypeKey(interpreter, locationRange, ty, access)
}
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func NewEntitlementSetAccess(
}
}

func NewAccessFromEntitlementSet(
func NewAccessFromEntitlementOrderedSet(
set *EntitlementOrderedSet,
setKind EntitlementSetKind,
) Access {
Expand Down
31 changes: 23 additions & 8 deletions runtime/sema/check_composite_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,35 @@ func (checker *Checker) checkAttachmentMembersAccess(attachmentType *CompositeTy
var supportedBaseEntitlements *EntitlementOrderedSet
baseType := attachmentType.GetBaseType()
if base, ok := attachmentType.GetBaseType().(EntitlementSupportingType); ok {
supportedBaseEntitlements = base.SupportedEntitlements()
// TODO:
access := base.SupportedEntitlements().Access()
if access, ok := access.(EntitlementSetAccess); ok {
supportedBaseEntitlements = access.Entitlements
}
}
if supportedBaseEntitlements == nil {
supportedBaseEntitlements = &orderedmap.OrderedMap[*EntitlementType, struct{}]{}
}

attachmentType.EffectiveInterfaceConformanceSet().ForEach(func(intf *InterfaceType) {
intf.Members.Foreach(func(_ string, member *Member) {
checker.checkAttachmentMemberAccess(attachmentType, member, baseType, supportedBaseEntitlements)
attachmentType.EffectiveInterfaceConformanceSet().
ForEach(func(interfaceType *InterfaceType) {
interfaceType.Members.Foreach(func(_ string, member *Member) {
checker.checkAttachmentMemberAccess(
attachmentType,
member,
baseType,
supportedBaseEntitlements,
)
})
})
})

attachmentType.Members.Foreach(func(_ string, member *Member) {
checker.checkAttachmentMemberAccess(attachmentType, member, baseType, supportedBaseEntitlements)
checker.checkAttachmentMemberAccess(
attachmentType,
member,
baseType,
supportedBaseEntitlements,
)
})

}
Expand Down Expand Up @@ -2081,7 +2096,7 @@ func (checker *Checker) checkDefaultDestroyEventParam(

// make `self` and `base` available when checking default arguments so the fields of the composite are available
// as this event is emitted when the resource is destroyed, these values should be fully entitled
fullyEntitledAccess := NewAccessFromEntitlementSet(containerType.SupportedEntitlements(), Conjunction)
fullyEntitledAccess := containerType.SupportedEntitlements().Access()

checker.declareSelfValue(
fullyEntitledAccess,
Expand Down Expand Up @@ -2224,7 +2239,7 @@ func (checker *Checker) checkSpecialFunction(
defer checker.leaveValueScope(specialFunction.EndPosition, checkResourceLoss)

// initializers and destructors are considered fully entitled to their container type
fnAccess := NewAccessFromEntitlementSet(containerType.SupportedEntitlements(), Conjunction)
fnAccess := containerType.SupportedEntitlements().Access()

checker.declareSelfValue(fnAccess, containerType, containerDocString)

Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func (checker *Checker) visitWithPostConditions(postConditions *ast.Conditions,
// here the `result` value in the `post` block will have type `auth(E, X, Y) &R`
if entitlementSupportingType, ok := innerType.(EntitlementSupportingType); ok {
supportedEntitlements := entitlementSupportingType.SupportedEntitlements()
auth = NewAccessFromEntitlementSet(supportedEntitlements, Conjunction)
auth = supportedEntitlements.Access()
}

resultType = &ReferenceType{
Expand Down
9 changes: 3 additions & 6 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,9 @@ func allSupportedEntitlements(typ Type, isInnerType bool) Access {
return allSupportedEntitlements(typ.ReturnTypeAnnotation.Type, true)
}
case EntitlementSupportingType:
supportedEntitlements := typ.SupportedEntitlements()
if supportedEntitlements != nil && supportedEntitlements.Len() > 0 {
return EntitlementSetAccess{
SetKind: Conjunction,
Entitlements: supportedEntitlements,
}
access := typ.SupportedEntitlements().Access()
if access != UnauthorizedAccess {
return access
}
}

Expand Down
191 changes: 191 additions & 0 deletions runtime/sema/entitlementset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Cadence - The resource-oriented smart contract programming language
*
* Copyright Dapper Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sema

import (
"sort"
"strings"

"github.com/onflow/cadence/runtime/common/orderedmap"
)

func disjunctionKey(disjunction *EntitlementOrderedSet) string {
// Gather type IDs, sorted
var typeIDs []string
disjunction.Foreach(func(entitlementType *EntitlementType, _ struct{}) {
typeIDs = append(typeIDs, string(entitlementType.ID()))
})
sort.Strings(typeIDs)

// Join type IDs
var sb strings.Builder
for index, typeID := range typeIDs {
if index > 0 {
sb.WriteByte('|')
}
sb.WriteString(typeID)
}
return sb.String()
}

// DisjunctionOrderedSet is a set of entitlement disjunctions, keyed by disjunctionKey
type DisjunctionOrderedSet = orderedmap.OrderedMap[string, *EntitlementOrderedSet]

// EntitlementSet is a set (conjunction) of entitlements and entitlement disjunctions.
// e.g. {entitlements: A, B; disjunctions: (C | D), (E | F)}
type EntitlementSet struct {
// Entitlements is a set of entitlements
Entitlements *EntitlementOrderedSet
// Disjunctions is a set of entitlement disjunctions, keyed by disjunctionKey
Disjunctions *DisjunctionOrderedSet
}

// Add adds an entitlement to the set.
//
// NOTE: The resulting set is potentially not minimal:
// If the set contains a disjunction that contains the entitlement,
// then the disjunction is NOT discarded.
// Call Minimize to obtain a minimal set.
func (s *EntitlementSet) Add(entitlementType *EntitlementType) {
if s.Entitlements == nil {
s.Entitlements = orderedmap.New[EntitlementOrderedSet](1)
}
s.Entitlements.Set(entitlementType, struct{}{})
}

// AddDisjunction adds an entitlement disjunction to the set.
// If the set already contains an entitlement of the given disjunction,
// then the disjunction is discarded.
func (s *EntitlementSet) AddDisjunction(disjunction *EntitlementOrderedSet) {
// If this set already contains an entitlement of the given disjunction,
// there is no need to add the disjunction.
if s.Entitlements != nil &&
disjunction.ForAnyKey(s.Entitlements.Contains) {

return
}

// If the disjunction already exists in the set,
// there is no need to add the disjunction.
key := disjunctionKey(disjunction)
if s.Disjunctions != nil && s.Disjunctions.Contains(key) {
return
}

if s.Disjunctions == nil {
s.Disjunctions = orderedmap.New[DisjunctionOrderedSet](1)
}
s.Disjunctions.Set(key, disjunction)
}

// Merge merges the other entitlement set into this set.
// The result is the union of the entitlements and disjunctions of both sets.
//
// The result is not necessarily minimal:
// For example, if s contains a disjunction d,
// and other contains an entitlement e that is part of d,
// then the result will still contain d.
// See Add.
// Call Minimize to obtain a minimal set.
func (s *EntitlementSet) Merge(other *EntitlementSet) {
if other.Entitlements != nil {
other.Entitlements.Foreach(func(key *EntitlementType, _ struct{}) {
s.Add(key)
})
}

if other.Disjunctions != nil {
other.Disjunctions.
Foreach(func(_ string, disjunction *EntitlementOrderedSet) {
s.AddDisjunction(disjunction)
})
}
}

// Minimize minimizes the entitlement set.
// It removes disjunctions that contain entitlements
// which are also in the entitlement set
func (s *EntitlementSet) Minimize() {
// If there are no entitlements or no disjunctions,
// there is nothing to minimize
if s.Entitlements == nil || s.Disjunctions == nil {
return
}

// Remove disjunctions that contain entitlements that are also in the entitlement set
var keysToRemove []string
s.Disjunctions.Foreach(func(key string, disjunction *EntitlementOrderedSet) {
if disjunction.ForAnyKey(s.Entitlements.Contains) {
keysToRemove = append(keysToRemove, key)
}
})

for _, key := range keysToRemove {
s.Disjunctions.Delete(key)
}
}

// Access returns the access represented by the entitlement set.
// The set is minimized before the access is computed.
func (s *EntitlementSet) Access() Access {
if s == nil {
return UnauthorizedAccess
}

s.Minimize()

var entitlements *EntitlementOrderedSet
if s.Entitlements != nil && s.Entitlements.Len() > 0 {
entitlements = orderedmap.New[EntitlementOrderedSet](s.Entitlements.Len())
entitlements.SetAll(s.Entitlements)
}

if s.Disjunctions != nil && s.Disjunctions.Len() > 0 {
if entitlements == nil {
// If there are no entitlements, and there is only one disjunction,
// then the access is the disjunction.
if s.Disjunctions.Len() == 1 {
onlyDisjunction := s.Disjunctions.Oldest().Value
return EntitlementSetAccess{
Entitlements: onlyDisjunction,
SetKind: Disjunction,
}
}

// There are no entitlements, but disjunctions.
// Allocate a new ordered map for all entitlements in the disjunctions
// (at minimum there are two entitlements in each disjunction).
entitlements = orderedmap.New[EntitlementOrderedSet](s.Disjunctions.Len() * 2)
}

// Add all entitlements in the disjunctions to the entitlements
s.Disjunctions.Foreach(func(_ string, disjunction *EntitlementOrderedSet) {
entitlements.SetAll(disjunction)
})
}

if entitlements == nil {
return UnauthorizedAccess
}

return EntitlementSetAccess{
Entitlements: entitlements,
SetKind: Conjunction,
}
}
Loading

0 comments on commit 93a063a

Please sign in to comment.