Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow entitlement mappings to be used in container-typed composite-fields #2587

Merged
Merged
55 changes: 26 additions & 29 deletions runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a
if !ok {
panic(errors.NewUnreachableError())
}
memberType := memberAccessInfo.Member.TypeAnnotation.Type

return getterSetter{
target: target,
Expand Down Expand Up @@ -225,7 +224,7 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a
// This is pre-computed at the checker.
if memberAccessInfo.ReturnReference {
// Get a reference to the value
resultValue = interpreter.getReferenceValue(resultValue, memberType)
resultValue = interpreter.getReferenceValue(resultValue, memberAccessInfo.ResultingType)
}

return resultValue
Expand All @@ -243,34 +242,38 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a
// This has to be done recursively for nested optionals.
// e.g.1: Given type T, this method returns &T.
// e.g.2: Given T?, this returns (&T)?
func (interpreter *Interpreter) getReferenceValue(value Value, semaType sema.Type) Value {
switch value.(type) {
func (interpreter *Interpreter) getReferenceValue(value Value, resultType sema.Type) Value {
switch value := value.(type) {
case NilValue, ReferenceValue:
// Reference to a nil, should return a nil.
// If the value is already a reference then return the same reference.
return value
case *SomeValue:
innerValue := interpreter.getReferenceValue(value.value, resultType)
return NewSomeValueNonCopying(interpreter, innerValue)
}

optionalType, ok := semaType.(*sema.OptionalType)
if ok {
semaType = optionalType.Type
// `resultType` is always an [optional] reference.
// This is guaranteed by the checker.
referenceType, ok := sema.UnwrapOptionalType(resultType).(*sema.ReferenceType)
if !ok {
panic(errors.NewUnreachableError())
}

// Because the boxing happens further down the code, it is possible
// to have a concrete value (non-some) for a place where optional is expected.
// Therefore, always unwrap the type, but only unwrap if the value is `SomeValue`.
//
// However, checker guarantees that the wise-versa doesn't happen.
// i.e: There will never be a `SomeValue`, with type being a non-optional.
auth := interpreter.getEffectiveAuthorization(referenceType)

if optionalValue, ok := value.(*SomeValue); ok {
value = optionalValue.value
}
innerValue := interpreter.getReferenceValue(value, semaType)
return NewSomeValueNonCopying(interpreter, innerValue)
interpreter.maybeTrackReferencedResourceKindedValue(value)
return NewEphemeralReferenceValue(interpreter, auth, value, referenceType.Type)
}

func (interpreter *Interpreter) getEffectiveAuthorization(referenceType *sema.ReferenceType) Authorization {
_, isMapped := referenceType.Authorization.(sema.EntitlementMapAccess)

if isMapped && interpreter.SharedState.currentEntitlementMappedValue != nil {
return interpreter.SharedState.currentEntitlementMappedValue
}

interpreter.maybeTrackReferencedResourceKindedValue(value)
return NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, value, semaType)
return ConvertSemaAccesstoStaticAuthorization(interpreter, referenceType.Authorization)
}

func (interpreter *Interpreter) checkMemberAccess(
Expand Down Expand Up @@ -924,10 +927,10 @@ func (interpreter *Interpreter) maybeGetReference(
) Value {
indexExpressionTypes := interpreter.Program.Elaboration.IndexExpressionTypes(expression)
if indexExpressionTypes.ReturnReference {
elementType := indexExpressionTypes.IndexedType.ElementType(false)
expectedType := indexExpressionTypes.ResultType

// Get a reference to the value
memberValue = interpreter.getReferenceValue(memberValue, elementType)
memberValue = interpreter.getReferenceValue(memberValue, expectedType)
}

return memberValue
Expand Down Expand Up @@ -1224,15 +1227,9 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as
interpreter.maybeTrackReferencedResourceKindedValue(result)

makeReference := func(value Value, typ *sema.ReferenceType) *EphemeralReferenceValue {
var auth Authorization

// if we are currently interpretering a function that was declared with mapped entitlement access, any appearances
// of that mapped access in the body of the function should be replaced with the computed output of the map
if _, isMapped := typ.Authorization.(sema.EntitlementMapAccess); isMapped && interpreter.SharedState.currentEntitlementMappedValue != nil {
auth = interpreter.SharedState.currentEntitlementMappedValue
} else {
auth = ConvertSemaAccesstoStaticAuthorization(interpreter, typ.Authorization)
}
auth := interpreter.getEffectiveAuthorization(typ)

return NewEphemeralReferenceValue(
interpreter,
Expand Down
4 changes: 3 additions & 1 deletion runtime/sema/check_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ func (checker *Checker) visitIndexExpression(
// then the element type should also be a reference.
returnReference := false
if !isAssignment && shouldReturnReference(valueIndexedType, elementType) {
elementType = checker.getReferenceType(elementType)
// For index expressions, element are un-authorized.
elementType = checker.getReferenceType(elementType, false, UnauthorizedAccess)

// Store the result in elaboration, so the interpreter can re-use this.
returnReference = true
Expand All @@ -333,6 +334,7 @@ func (checker *Checker) visitIndexExpression(
IndexExpressionTypes{
IndexedType: valueIndexedType,
IndexingType: indexingType,
ResultType: elementType,
ReturnReference: returnReference,
},
)
Expand Down
147 changes: 78 additions & 69 deletions runtime/sema/check_member_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (checker *Checker) VisitMemberExpression(expression *ast.MemberExpression)
// in an optional, if it is not already an optional value
if isOptional {
if _, ok := memberType.(*OptionalType); !ok {
memberType = &OptionalType{Type: memberType}
memberType = NewOptionalType(checker.memoryGauge, memberType)
}
}

Expand All @@ -97,14 +97,18 @@ func (checker *Checker) VisitMemberExpression(expression *ast.MemberExpression)
// This has to be done recursively for nested optionals.
// e.g.1: Given type T, this method returns &T.
// e.g.2: Given T?, this returns (&T)?
func (checker *Checker) getReferenceType(typ Type) Type {
func (checker *Checker) getReferenceType(typ Type, substituteAuthorization bool, authorization Access) Type {
if optionalType, ok := typ.(*OptionalType); ok {
return &OptionalType{
Type: checker.getReferenceType(optionalType.Type),
}
innerType := checker.getReferenceType(optionalType.Type, substituteAuthorization, authorization)
return NewOptionalType(checker.memoryGauge, innerType)
}

auth := UnauthorizedAccess
if substituteAuthorization && authorization != nil {
auth = authorization
}

return NewReferenceType(checker.memoryGauge, typ, UnauthorizedAccess)
return NewReferenceType(checker.memoryGauge, typ, auth)
}

func shouldReturnReference(parentType, memberType Type) bool {
Expand Down Expand Up @@ -284,80 +288,84 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT
},
)
}
} else {

if checker.PositionInfo != nil {
checker.PositionInfo.recordMemberOccurrence(
accessedType,
identifier,
identifierStartPosition,
identifierEndPosition,
)
}
return
}

// Check access and report if inaccessible
accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) }
isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, accessRange)
if !isReadable {
checker.report(
&InvalidAccessError{
Name: member.Identifier.Identifier,
RestrictingAccess: member.Access,
DeclarationKind: member.DeclarationKind,
Range: accessRange(),
},
)
}
if checker.PositionInfo != nil {
checker.PositionInfo.recordMemberOccurrence(
accessedType,
identifier,
identifierStartPosition,
identifierEndPosition,
)
}

// the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type
// i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`,
// we substitute this entitlement in for the "variable" `M` to produce `auth(E) &T?`, the access with which the type is actually produced.
// Equivalently, this can be thought of like generic instantiation.
substituteConcreteAuthorization := func(resultingType Type) Type {
switch ty := resultingType.(type) {
// Check access and report if inaccessible
accessRange := func() ast.Range { return ast.NewRangeFromPositioned(checker.memoryGauge, expression) }
isReadable, resultingAuthorization := checker.isReadableMember(accessedType, member, accessRange)
if !isReadable {
checker.report(
&InvalidAccessError{
Name: member.Identifier.Identifier,
RestrictingAccess: member.Access,
DeclarationKind: member.DeclarationKind,
Range: accessRange(),
},
)
}

// the resulting authorization was mapped through an entitlement map, so we need to substitute this new authorization into the resulting type
// i.e. if the field was declared with `access(M) let x: auth(M) &T?`, and we computed that the output of the map would give entitlement `E`,
// we substitute this entitlement in for the "variable" `M` to produce `auth(E) &T?`, the access with which the type is actually produced.
// Equivalently, this can be thought of like generic instantiation.
substituteConcreteAuthorization := func(resultingType Type) Type {
switch ty := resultingType.(type) {
case *ReferenceType:
return NewReferenceType(checker.memoryGauge, ty.Type, resultingAuthorization)
case *OptionalType:
switch innerTy := ty.Type.(type) {
case *ReferenceType:
return NewReferenceType(checker.memoryGauge, ty.Type, resultingAuthorization)
case *OptionalType:
switch innerTy := ty.Type.(type) {
case *ReferenceType:
return NewOptionalType(checker.memoryGauge,
NewReferenceType(checker.memoryGauge, innerTy.Type, resultingAuthorization))
}
}
return resultingType
}
if !member.Access.Equal(resultingAuthorization) {
switch ty := resultingType.(type) {
case *FunctionType:
resultingType = NewSimpleFunctionType(
ty.Purity,
ty.Parameters,
NewTypeAnnotation(substituteConcreteAuthorization(ty.ReturnTypeAnnotation.Type)),
)
default:
resultingType = substituteConcreteAuthorization(resultingType)
return NewOptionalType(checker.memoryGauge,
NewReferenceType(checker.memoryGauge, innerTy.Type, resultingAuthorization))
}
}
return resultingType
}

// Check that the member access is not to a function of resource type
// outside of an invocation of it.
//
// This would result in a bound method for a resource, which is invalid.

if !checker.inAssignment &&
!checker.inInvocation &&
member.DeclarationKind == common.DeclarationKindFunction &&
!accessedType.IsInvalidType() &&
accessedType.IsResourceType() {
shouldSubstituteAuthorization := !member.Access.Equal(resultingAuthorization)

checker.report(
&ResourceMethodBindingError{
Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression),
},
if shouldSubstituteAuthorization {
switch ty := resultingType.(type) {
case *FunctionType:
resultingType = NewSimpleFunctionType(
ty.Purity,
ty.Parameters,
NewTypeAnnotation(substituteConcreteAuthorization(ty.ReturnTypeAnnotation.Type)),
)
default:
resultingType = substituteConcreteAuthorization(resultingType)
}
}

// Check that the member access is not to a function of resource type
// outside of an invocation of it.
//
// This would result in a bound method for a resource, which is invalid.

if !checker.inAssignment &&
!checker.inInvocation &&
member.DeclarationKind == common.DeclarationKindFunction &&
!accessedType.IsInvalidType() &&
accessedType.IsResourceType() {

checker.report(
&ResourceMethodBindingError{
Range: ast.NewRangeFromPositioned(checker.memoryGauge, expression),
},
)
}

// If the member,
// 1) is accessed via a reference, and
// 2) is container-typed,
Expand All @@ -372,8 +380,9 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression) (accessedT
if accessedSelfMember == nil &&
shouldReturnReference(accessedType, resultingType) &&
member.DeclarationKind == common.DeclarationKindField {

// Get a reference to the type
resultingType = checker.getReferenceType(resultingType)
resultingType = checker.getReferenceType(resultingType, shouldSubstituteAuthorization, resultingAuthorization)
returnReference = true
}

Expand Down
Loading