diff --git a/docs/language/run-time-types.md b/docs/language/run-time-types.md index c8b98cb96c..f0d8306902 100644 --- a/docs/language/run-time-types.md +++ b/docs/language/run-time-types.md @@ -83,6 +83,49 @@ let type: Type = something.getType() // `type` is `Type<@Collectible>()` ``` +### Constructing a Run-time Type + +Run-time types can also be constructed from type identifier strings using built-in constructor functions. + +```cadence +fun CompositeType(_ identifier: String): Type? +fun InterfaceType(_ identifier: String): Type? +fun RestrictedType(identifier: String?, restrictions: [String]): Type? +``` + +Given a type identifer (as well as a list of identifiers for restricting interfaces +in the case of `RestrictedType`), these functions will look up nominal types and +produce their run-time equivalents. If the provided identifiers do not correspond +to any types, or (in the case of `RestrictedType`) the provided combination of +identifiers would not type-check statically, these functions will produce `nil`. + +```cadence +struct Test {} +struct interface I {} +let type: Type = CompositeType("A.0000000000000001.Test") +// `type` is `Type` + +let type2: Type = RestrictedType( + identifier: type.identifier, + restrictions: ["A.0000000000000001.I"] +) +// `type2` is `Type` +``` + +Other built-in functions will construct compound types from other run-types. + +```cadence +fun OptionalType(_ type: Type): Type +fun VariableSizedArrayType(_ type: Type): Type +fun ConstantSizedArrayType(type: Type, size: Int): Type +fun FunctionType(parameters: [Type], return: Type): Type +// returns `nil` if `key` is not valid dictionary key type +fun DictionaryType(key: Type, value: Type): Type? +// returns `nil` if `type` is not a reference type +fun CapabilityType(_ type: Type): Type? +fun ReferenceType(authorized: bool, type: Type): Type +``` + ### Asserting the Type of a Value The method `fun isInstance(_ type: Type): Bool` can be used to check if a value has a certain type, diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index 24fc54459d..6618fb16d5 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -521,3 +521,16 @@ func (e NonStorableStaticTypeError) Error() string { e.Type, ) } + +// InterfaceMissingLocation is reported during interface lookup, +// if an interface is looked up without a location +type InterfaceMissingLocationError struct { + QualifiedIdentifier string +} + +func (e *InterfaceMissingLocationError) Error() string { + return fmt.Sprintf( + "tried to look up interface %s without a location", + e.QualifiedIdentifier, + ) +} diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 7fbc88f392..0c7ebfc58a 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -2615,6 +2615,36 @@ var converterDeclarations = []valueConverterDeclaration{ }, } +func lookupInterface(interpreter *Interpreter, typeID string) (*sema.InterfaceType, error) { + location, qualifiedIdentifier, err := common.DecodeTypeID(typeID) + // if the typeID is invalid, return nil + if err != nil { + return nil, err + } + + typ, err := interpreter.getInterfaceType(location, qualifiedIdentifier) + if err != nil { + return nil, err + } + + return typ, nil +} + +func lookupComposite(interpreter *Interpreter, typeID string) (*sema.CompositeType, error) { + location, qualifiedIdentifier, err := common.DecodeTypeID(typeID) + // if the typeID is invalid, return nil + if err != nil { + return nil, err + } + + typ, err := interpreter.getCompositeType(location, qualifiedIdentifier, common.TypeID(typeID)) + if err != nil { + return nil, err + } + + return typ, nil +} + func init() { converterNames := make(map[string]struct{}, len(converterDeclarations)) @@ -2639,11 +2669,173 @@ func init() { panic(fmt.Sprintf("missing converter for number type: %s", numberType)) } } + + // We assign this here because it depends on the interpreter, so this breaks the initialization cycle + defineBaseValue( + baseActivation, + "DictionaryType", + NewHostFunctionValue( + func(invocation Invocation) Value { + keyType := invocation.Arguments[0].(TypeValue).Type + valueType := invocation.Arguments[1].(TypeValue).Type + + // if the given key is not a valid dictionary key, it wouldn't make sense to create this type + if keyType == nil || + !sema.IsValidDictionaryKeyType(invocation.Interpreter.ConvertStaticToSemaType(keyType)) { + return NilValue{} + } + + return NewSomeValueNonCopying(TypeValue{ + Type: DictionaryStaticType{ + KeyType: keyType, + ValueType: valueType, + }}) + }, + sema.DictionaryTypeFunctionType, + )) + + defineBaseValue( + baseActivation, + "CompositeType", + NewHostFunctionValue( + func(invocation Invocation) Value { + typeID := invocation.Arguments[0].(*StringValue).Str + + composite, err := lookupComposite(invocation.Interpreter, typeID) + if err != nil { + return NilValue{} + } + + return NewSomeValueNonCopying(TypeValue{ + Type: ConvertSemaToStaticType(composite), + }) + }, + sema.CompositeTypeFunctionType, + ), + ) + + defineBaseValue( + baseActivation, + "InterfaceType", + NewHostFunctionValue( + func(invocation Invocation) Value { + typeID := invocation.Arguments[0].(*StringValue).Str + + interfaceType, err := lookupInterface(invocation.Interpreter, typeID) + if err != nil { + return NilValue{} + } + + return NewSomeValueNonCopying(TypeValue{ + Type: ConvertSemaToStaticType(interfaceType), + }) + }, + sema.InterfaceTypeFunctionType, + ), + ) + + defineBaseValue( + baseActivation, + "FunctionType", + NewHostFunctionValue( + func(invocation Invocation) Value { + parameters := invocation.Arguments[0].(*ArrayValue) + returnType := invocation.Interpreter.ConvertStaticToSemaType(invocation.Arguments[1].(TypeValue).Type) + parameterTypes := make([]*sema.Parameter, 0, parameters.Count()) + parameters.Iterate(func(param Value) bool { + parameterTypes = append( + parameterTypes, + &sema.Parameter{ + TypeAnnotation: sema.NewTypeAnnotation(invocation.Interpreter.ConvertStaticToSemaType(param.(TypeValue).Type)), + }, + ) + return true + }) + return TypeValue{ + Type: FunctionStaticType{ + Type: &sema.FunctionType{ + ReturnTypeAnnotation: sema.NewTypeAnnotation(returnType), + Parameters: parameterTypes, + }, + }} + }, + sema.FunctionTypeFunctionType, + ), + ) + + defineBaseValue( + baseActivation, + "RestrictedType", + NewHostFunctionValue( + RestrictedTypeFunction, + sema.RestrictedTypeFunctionType, + ), + ) +} + +func RestrictedTypeFunction(invocation Invocation) Value { + restrictedIDs := invocation.Arguments[1].(*ArrayValue) + + staticRestrictions := make([]InterfaceStaticType, 0, restrictedIDs.Count()) + semaRestrictions := make([]*sema.InterfaceType, 0, restrictedIDs.Count()) + ok := true + + restrictedIDs.Iterate(func(typeID Value) bool { + restrictionInterface, err := lookupInterface(invocation.Interpreter, typeID.(*StringValue).Str) + if err != nil { + ok = false + return true + } + + staticRestrictions = append(staticRestrictions, ConvertSemaToStaticType(restrictionInterface).(InterfaceStaticType)) + semaRestrictions = append(semaRestrictions, restrictionInterface) + return true + }) + + if !ok { + return NilValue{} + } + + var semaType sema.Type + var err error + + switch typeID := invocation.Arguments[0].(type) { + case NilValue: + semaType = nil + case *SomeValue: + semaType, err = lookupComposite(invocation.Interpreter, typeID.Value.(*StringValue).Str) + if err != nil { + return NilValue{} + } + default: + panic(errors.NewUnreachableError()) + } + + ok = true + ty := sema.CheckRestrictedType( + semaType, + semaRestrictions, + func(_ func(*ast.RestrictedType) error) { + ok = false + }, + ) + + // if the restricted type would have failed to typecheck statically, we return nil + if !ok { + return NilValue{} + } + return NewSomeValueNonCopying(TypeValue{ + Type: &RestrictedStaticType{ + Type: ConvertSemaToStaticType(ty), + Restrictions: staticRestrictions, + }, + }) } func defineBaseFunctions(activation *VariableActivation) { defineConverterFunctions(activation) defineTypeFunction(activation) + defineRuntimeTypeConstructorFunctions(activation) defineStringFunction(activation) } @@ -2701,6 +2893,94 @@ func defineConverterFunctions(activation *VariableActivation) { } } +type runtimeTypeConstructor struct { + name string + converter *HostFunctionValue +} + +// Constructor functions are stateless functions. Hence they can be re-used across interpreters. +// +var runtimeTypeConstructors = []runtimeTypeConstructor{ + { + name: "OptionalType", + converter: NewHostFunctionValue( + func(invocation Invocation) Value { + return TypeValue{ + Type: OptionalStaticType{ + Type: invocation.Arguments[0].(TypeValue).Type, + }} + + }, + sema.OptionalTypeFunctionType, + ), + }, + { + name: "VariableSizedArrayType", + converter: NewHostFunctionValue( + func(invocation Invocation) Value { + return TypeValue{ + Type: VariableSizedStaticType{ + Type: invocation.Arguments[0].(TypeValue).Type, + }} + }, + sema.VariableSizedArrayTypeFunctionType, + ), + }, + { + name: "ConstantSizedArrayType", + converter: NewHostFunctionValue( + func(invocation Invocation) Value { + return TypeValue{ + Type: ConstantSizedStaticType{ + Type: invocation.Arguments[0].(TypeValue).Type, + Size: int64(invocation.Arguments[1].(IntValue).ToInt()), + }} + }, + sema.ConstantSizedArrayTypeFunctionType, + ), + }, + { + name: "ReferenceType", + converter: NewHostFunctionValue( + func(invocation Invocation) Value { + return TypeValue{ + Type: ReferenceStaticType{ + Authorized: bool(invocation.Arguments[0].(BoolValue)), + Type: invocation.Arguments[1].(TypeValue).Type, + }} + }, + sema.ReferenceTypeFunctionType, + ), + }, + { + name: "CapabilityType", + converter: NewHostFunctionValue( + func(invocation Invocation) Value { + ty := invocation.Arguments[0].(TypeValue).Type + // Capabilities must hold references + _, ok := ty.(ReferenceStaticType) + if !ok { + return NilValue{} + } + return NewSomeValueNonCopying( + TypeValue{ + Type: CapabilityStaticType{ + BorrowType: ty, + }, + }, + ) + }, + sema.CapabilityTypeFunctionType, + ), + }, +} + +func defineRuntimeTypeConstructorFunctions(activation *VariableActivation) { + for _, constructorFunc := range runtimeTypeConstructors { + defineBaseValue(activation, constructorFunc.name, constructorFunc.converter) + } +} + // typeFunction is the `Type` function. It is stateless, hence it can be re-used across interpreters. // var typeFunction = NewHostFunctionValue( @@ -3470,14 +3750,18 @@ func (interpreter *Interpreter) ConvertStaticToSemaType(staticType StaticType) s return ConvertStaticToSemaType( staticType, func(location common.Location, qualifiedIdentifier string) *sema.InterfaceType { - return interpreter.getInterfaceType(location, qualifiedIdentifier) + interfaceType, err := interpreter.getInterfaceType(location, qualifiedIdentifier) + if err != nil { + panic(err) + } + return interfaceType }, func(location common.Location, qualifiedIdentifier string, typeID common.TypeID) *sema.CompositeType { - if location == nil { - return interpreter.getNativeCompositeType(qualifiedIdentifier) + compositeType, err := interpreter.getCompositeType(location, qualifiedIdentifier, typeID) + if err != nil { + panic(err) } - - return interpreter.getUserCompositeType(location, typeID) + return compositeType }, ) } @@ -3521,52 +3805,64 @@ func (interpreter *Interpreter) GetContractComposite(contractLocation common.Add return contractValue, nil } -func (interpreter *Interpreter) getUserCompositeType(location common.Location, typeID common.TypeID) *sema.CompositeType { +func (interpreter *Interpreter) getCompositeType(location common.Location, qualifiedIdentifier string, typeID common.TypeID) (*sema.CompositeType, error) { + if location == nil { + return interpreter.getNativeCompositeType(qualifiedIdentifier) + } + + return interpreter.getUserCompositeType(location, typeID) +} + +func (interpreter *Interpreter) getUserCompositeType(location common.Location, typeID common.TypeID) (*sema.CompositeType, error) { elaboration := interpreter.getElaboration(location) if elaboration == nil { - panic(TypeLoadingError{ + return nil, TypeLoadingError{ TypeID: typeID, - }) + } } ty := elaboration.CompositeTypes[typeID] if ty == nil { - panic(TypeLoadingError{ + return nil, TypeLoadingError{ TypeID: typeID, - }) + } } - return ty + return ty, nil } -func (interpreter *Interpreter) getNativeCompositeType(qualifiedIdentifier string) *sema.CompositeType { +func (interpreter *Interpreter) getNativeCompositeType(qualifiedIdentifier string) (*sema.CompositeType, error) { ty := sema.NativeCompositeTypes[qualifiedIdentifier] if ty == nil { - panic(TypeLoadingError{ + return ty, TypeLoadingError{ TypeID: common.TypeID(qualifiedIdentifier), - }) + } } - return ty + return ty, nil } -func (interpreter *Interpreter) getInterfaceType(location common.Location, qualifiedIdentifier string) *sema.InterfaceType { +func (interpreter *Interpreter) getInterfaceType(location common.Location, qualifiedIdentifier string) (*sema.InterfaceType, error) { + if location == nil { + return nil, &InterfaceMissingLocationError{QualifiedIdentifier: qualifiedIdentifier} + } + typeID := location.TypeID(qualifiedIdentifier) elaboration := interpreter.getElaboration(location) if elaboration == nil { - panic(TypeLoadingError{ + return nil, TypeLoadingError{ TypeID: typeID, - }) + } } ty := elaboration.InterfaceTypes[typeID] if ty == nil { - panic(TypeLoadingError{ + return nil, TypeLoadingError{ TypeID: typeID, - }) + } } - return ty + return ty, nil } func (interpreter *Interpreter) reportLoopIteration(pos ast.HasPosition) { diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index ad088a8d7b..1a8aff2f1a 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -8047,10 +8047,14 @@ func (v *CompositeValue) Walk(walkChild func(Value)) { func (v *CompositeValue) DynamicType(interpreter *Interpreter, _ SeenReferences) DynamicType { if v.dynamicType == nil { var staticType sema.Type + var err error if v.Location == nil { - staticType = interpreter.getNativeCompositeType(v.QualifiedIdentifier) + staticType, err = interpreter.getNativeCompositeType(v.QualifiedIdentifier) } else { - staticType = interpreter.getUserCompositeType(v.Location, v.TypeID()) + staticType, err = interpreter.getUserCompositeType(v.Location, v.TypeID()) + } + if err != nil { + panic(err) } v.dynamicType = CompositeDynamicType{ StaticType: staticType, diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 579457c2e7..770021328a 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -938,76 +938,42 @@ func (checker *Checker) ConvertType(t ast.Type) Type { panic(&astTypeConversionError{invalidASTType: t}) } -func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { - var restrictedType Type - - // Convert the restricted type, if any - - if t.Type != nil { - restrictedType = checker.ConvertType(t.Type) - } - - // Convert the restrictions - - var restrictions []*InterfaceType - restrictionRanges := make(map[*InterfaceType]ast.Range, len(t.Restrictions)) - - memberSet := map[string]*InterfaceType{} - +func CheckRestrictedType(restrictedType Type, restrictions []*InterfaceType, report func(func(*ast.RestrictedType) error)) Type { + restrictionRanges := make(map[*InterfaceType]func(*ast.RestrictedType) ast.Range, len(restrictions)) restrictionsCompositeKind := common.CompositeKindUnknown + memberSet := map[string]*InterfaceType{} - for _, restriction := range t.Restrictions { - restrictionResult := checker.ConvertType(restriction) - - // The restriction must be a resource or structure interface type - - restrictionInterfaceType, ok := restrictionResult.(*InterfaceType) - restrictionCompositeKind := common.CompositeKindUnknown - if ok { - restrictionCompositeKind = restrictionInterfaceType.CompositeKind - } - if !ok || (restrictionCompositeKind != common.CompositeKindResource && - restrictionCompositeKind != common.CompositeKindStructure) { - - if !restrictionResult.IsInvalidType() { - checker.report( - &InvalidRestrictionTypeError{ - Type: restrictionResult, - Range: ast.NewRangeFromPositioned(restriction), - }, - ) - } - continue - } + for i, restrictionInterfaceType := range restrictions { + restrictionCompositeKind := restrictionInterfaceType.CompositeKind if restrictionsCompositeKind == common.CompositeKindUnknown { restrictionsCompositeKind = restrictionCompositeKind } else if restrictionCompositeKind != restrictionsCompositeKind { - - checker.report( - &RestrictionCompositeKindMismatchError{ + report(func(t *ast.RestrictedType) error { + return &RestrictionCompositeKindMismatchError{ CompositeKind: restrictionCompositeKind, PreviousCompositeKind: restrictionsCompositeKind, - Range: ast.NewRangeFromPositioned(restriction), - }, - ) + Range: ast.NewRangeFromPositioned(t.Restrictions[i]), + } + }) } - restrictions = append(restrictions, restrictionInterfaceType) - // The restriction must not be duplicated if _, exists := restrictionRanges[restrictionInterfaceType]; exists { - checker.report( - &InvalidRestrictionTypeDuplicateError{ + report(func(t *ast.RestrictedType) error { + return &InvalidRestrictionTypeDuplicateError{ Type: restrictionInterfaceType, - Range: ast.NewRangeFromPositioned(restriction), - }, - ) + Range: ast.NewRangeFromPositioned(t.Restrictions[i]), + } + }) + } else { restrictionRanges[restrictionInterfaceType] = - ast.NewRangeFromPositioned(restriction) + func(t *ast.RestrictedType) ast.Range { + return ast.NewRangeFromPositioned(t.Restrictions[i]) + } } // The restrictions may not have clashing members @@ -1033,14 +999,14 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { !previousMemberType.IsInvalidType() && !memberType.Equal(previousMemberType) { - checker.report( - &RestrictionMemberClashError{ + report(func(t *ast.RestrictedType) error { + return &RestrictionMemberClashError{ Name: name, RedeclaringType: restrictionInterfaceType, OriginalDeclaringType: previousDeclaringInterfaceType, - Range: ast.NewRangeFromPositioned(restriction), - }, - ) + Range: ast.NewRangeFromPositioned(t.Restrictions[i]), + } + }) } } else { memberSet[name] = restrictionInterfaceType @@ -1048,7 +1014,9 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { }) } - if restrictedType == nil { + var hadExplicitType = restrictedType != nil + + if !hadExplicitType { // If no restricted type is given, infer `AnyResource`/`AnyStruct` // based on the composite kind of the restrictions. @@ -1059,11 +1027,9 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { restrictedType = InvalidType - checker.report( - &AmbiguousRestrictedTypeError{ - Range: ast.NewRangeFromPositioned(t), - }, - ) + report(func(t *ast.RestrictedType) error { + return &AmbiguousRestrictedTypeError{Range: ast.NewRangeFromPositioned(t)} + }) case common.CompositeKindResource: restrictedType = AnyResourceType @@ -1080,12 +1046,12 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { // or `AnyResource`/`AnyStruct` reportInvalidRestrictedType := func() { - checker.report( - &InvalidRestrictedTypeError{ + report(func(t *ast.RestrictedType) error { + return &InvalidRestrictedTypeError{ Type: restrictedType, Range: ast.NewRangeFromPositioned(t.Type), - }, - ) + } + }) } var compositeType *CompositeType @@ -1110,7 +1076,7 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { break default: - if t.Type != nil { + if hadExplicitType { reportInvalidRestrictedType() } } @@ -1131,16 +1097,63 @@ func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { // of the composite (restricted type) if !conformances.Includes(restriction) { - checker.report( - &InvalidNonConformanceRestrictionError{ + report(func(t *ast.RestrictedType) error { + return &InvalidNonConformanceRestrictionError{ Type: restriction, - Range: restrictionRanges[restriction], - }, - ) + Range: restrictionRanges[restriction](t), + } + }) + } + } + } + return restrictedType +} + +func (checker *Checker) convertRestrictedType(t *ast.RestrictedType) Type { + var restrictedType Type + + // Convert the restricted type, if any + + if t.Type != nil { + restrictedType = checker.ConvertType(t.Type) + } + + // Convert the restrictions + + var restrictions []*InterfaceType + + for _, restriction := range t.Restrictions { + restrictionResult := checker.ConvertType(restriction) + + // The restriction must be a resource or structure interface type + + restrictionInterfaceType, ok := restrictionResult.(*InterfaceType) + restrictionCompositeKind := common.CompositeKindUnknown + if ok { + restrictionCompositeKind = restrictionInterfaceType.CompositeKind + } + if !ok || (restrictionCompositeKind != common.CompositeKindResource && + restrictionCompositeKind != common.CompositeKindStructure) { + + if !restrictionResult.IsInvalidType() { + checker.report(&InvalidRestrictionTypeError{ + Type: restrictionResult, + Range: ast.NewRangeFromPositioned(restriction), + }) } } + + restrictions = append(restrictions, restrictionInterfaceType) } + restrictedType = CheckRestrictedType( + restrictedType, + restrictions, + func(getError func(*ast.RestrictedType) error) { + checker.report(getError(t)) + }, + ) + return &RestrictedType{ Type: restrictedType, Restrictions: restrictions, diff --git a/runtime/sema/runtime_type_constructors.go b/runtime/sema/runtime_type_constructors.go new file mode 100644 index 0000000000..2ca9027f2b --- /dev/null +++ b/runtime/sema/runtime_type_constructors.go @@ -0,0 +1,216 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright 2019-2021 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +type RuntimeTypeConstructor struct { + Name string + Value *FunctionType + DocString string +} + +var OptionalTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "type", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(MetaType), +} + +var VariableSizedArrayTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "type", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(MetaType), +} + +var ConstantSizedArrayTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Identifier: "type", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + { + Identifier: "size", + TypeAnnotation: NewTypeAnnotation(IntType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(MetaType), +} + +var DictionaryTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Identifier: "key", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + { + Identifier: "value", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&OptionalType{MetaType}), +} + +var CompositeTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "identifier", + TypeAnnotation: NewTypeAnnotation(StringType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&OptionalType{MetaType}), +} + +var InterfaceTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "identifier", + TypeAnnotation: NewTypeAnnotation(StringType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&OptionalType{MetaType}), +} + +var FunctionTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Identifier: "parameters", + TypeAnnotation: NewTypeAnnotation(&VariableSizedType{Type: MetaType}), + }, + { + Identifier: "return", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(MetaType), +} + +var RestrictedTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Identifier: "identifier", + TypeAnnotation: NewTypeAnnotation(&OptionalType{StringType}), + }, + { + Identifier: "restrictions", + TypeAnnotation: NewTypeAnnotation(&VariableSizedType{Type: StringType}), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&OptionalType{MetaType}), +} + +var ReferenceTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Identifier: "authorized", + TypeAnnotation: NewTypeAnnotation(BoolType), + }, + { + Identifier: "type", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(MetaType), +} + +var CapabilityTypeFunctionType = &FunctionType{ + Parameters: []*Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "type", + TypeAnnotation: NewTypeAnnotation(MetaType), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation(&OptionalType{MetaType}), +} + +var runtimeTypeConstructors = []*RuntimeTypeConstructor{ + { + "OptionalType", + OptionalTypeFunctionType, + "Creates a run-time type representing an optional version of the given run-time type.", + }, + + { + "VariableSizedArrayType", + VariableSizedArrayTypeFunctionType, + "Creates a run-time type representing a variable-sized array type of the given run-time type.", + }, + + { + "ConstantSizedArrayType", + ConstantSizedArrayTypeFunctionType, + "Creates a run-time type representing a constant-sized array type of the given run-time type with the specifized size.", + }, + + { + "DictionaryType", + DictionaryTypeFunctionType, + `Creates a run-time type representing a dictionary type of the given run-time key and value types. + Returns nil if the key type is not a valid dictionary key.`, + }, + + { + "CompositeType", + CompositeTypeFunctionType, + `Creates a run-time type representing the composite type associated with the given type identifier. + Returns nil if the identifier does not correspond to any composite type.`, + }, + + { + "InterfaceType", + InterfaceTypeFunctionType, + `Creates a run-time type representing the interface type associated with the given type identifier. + Returns nil if the identifier does not correspond to any interface type.`, + }, + + { + "FunctionType", + FunctionTypeFunctionType, + "Creates a run-time type representing a function type associated with the given parameters and return type.", + }, + + { + "ReferenceType", + ReferenceTypeFunctionType, + "Creates a run-time type representing a reference type of the given type, with authorization provided by the first argument.", + }, + + { + "RestrictedType", + RestrictedTypeFunctionType, + `Creates a run-time type representing a restricted type of the first argument, restricted by the interface identifiers in the second argument. + Returns nil if the restriction is not valid.`, + }, + + { + "CapabilityType", + CapabilityTypeFunctionType, + "Creates a run-time type representing a capability type of the given reference type. Returns nil if the type is not a reference.", + }, +} diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 488426c2b6..ca8be6dbbd 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -2990,6 +2990,7 @@ func baseFunctionVariable(name string, ty *FunctionType, docString string) *Vari return &Variable{ Identifier: name, DeclarationKind: common.DeclarationKindFunction, + ArgumentLabels: ty.ArgumentLabels(), IsConstant: true, IsBaseValue: true, Type: ty, @@ -3259,6 +3260,16 @@ func init() { "Creates a run-time type representing the given static type as a value", ), ) + + for _, v := range runtimeTypeConstructors { + BaseValueActivation.Set( + v.Name, + baseFunctionVariable( + v.Name, + v.Value, + v.DocString, + )) + } } // CompositeType diff --git a/runtime/tests/checker/runtimetype_test.go b/runtime/tests/checker/runtimetype_test.go new file mode 100644 index 0000000000..9dfb8270e0 --- /dev/null +++ b/runtime/tests/checker/runtimetype_test.go @@ -0,0 +1,854 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright 2019-2021 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package checker + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/sema" +) + +func TestCheckOptionalTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "String", + code: ` + let result = OptionalType(Type()) + `, + expectedError: nil, + }, + { + name: "Int", + code: ` + let result = OptionalType(Type()) + `, + expectedError: nil, + }, + { + name: "resource", + code: ` + resource R {} + let result = OptionalType(Type<@R>()) + `, + expectedError: nil, + }, + { + name: "type mismatch", + code: ` + let result = OptionalType(3) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = OptionalType(Type(), Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "too few args", + code: ` + let result = OptionalType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + sema.MetaType, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckVariableSizedArrayTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "String", + code: ` + let result = VariableSizedArrayType(Type()) + `, + expectedError: nil, + }, + { + name: "Int", + code: ` + let result = VariableSizedArrayType(Type()) + `, + expectedError: nil, + }, + { + name: "resource", + code: ` + resource R {} + let result = VariableSizedArrayType(Type<@R>()) + `, + expectedError: nil, + }, + { + name: "type mismatch", + code: ` + let result = VariableSizedArrayType(3) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = VariableSizedArrayType(Type(), Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "too few args", + code: ` + let result = VariableSizedArrayType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + sema.MetaType, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckConstantSizedArrayTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "String", + code: ` + let result = ConstantSizedArrayType(type: Type(), size: 3) + `, + expectedError: nil, + }, + { + name: "Int", + code: ` + let result = ConstantSizedArrayType(type: Type(), size: 2) + `, + expectedError: nil, + }, + { + name: "resource", + code: ` + resource R {} + let result = ConstantSizedArrayType(type: Type<@R>(), size: 4) + `, + expectedError: nil, + }, + { + name: "type mismatch first arg", + code: ` + let result = ConstantSizedArrayType(type: 3, size: 4) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch second arg", + code: ` + let result = ConstantSizedArrayType(type: Type(), size: "") + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = ConstantSizedArrayType(type:Type(), size: 3, 4) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "one arg", + code: ` + let result = ConstantSizedArrayType(type: Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = ConstantSizedArrayType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "second label missing", + code: ` + let result = ConstantSizedArrayType(type: Type(), 3) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + { + name: "first label missing", + code: ` + let result = ConstantSizedArrayType(Type(), size: 3) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + sema.MetaType, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckDictionaryTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "String/Int", + code: ` + let result = DictionaryType(key: Type(), value: Type()) + `, + expectedError: nil, + }, + { + name: "Int/String", + code: ` + let result = DictionaryType(key: Type(), value: Type()) + `, + expectedError: nil, + }, + { + name: "resource/struct", + code: ` + resource R {} + struct S {} + let result = DictionaryType(key: Type<@R>(), value: Type()) + `, + expectedError: nil, + }, + { + name: "type mismatch first arg", + code: ` + let result = DictionaryType(key: 3, value: Type()) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch second arg", + code: ` + let result = DictionaryType(key: Type(), value: "") + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = DictionaryType(key: Type(), value: Type(), 4) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "one arg", + code: ` + let result = DictionaryType(key: Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = DictionaryType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "first label missing", + code: ` + let result = DictionaryType(Type(), value: Type()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + { + name: "second label missing", + code: ` + let result = DictionaryType(key: Type(), Type()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + &sema.OptionalType{Type: sema.MetaType}, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckCompositeTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "R", + code: ` + let result = CompositeType("R") + `, + expectedError: nil, + }, + { + name: "type mismatch", + code: ` + let result = CompositeType(3) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = CompositeType("", 3) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = CompositeType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + &sema.OptionalType{Type: sema.MetaType}, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckInterfaceTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "R", + code: ` + let result = InterfaceType("R") + `, + expectedError: nil, + }, + { + name: "type mismatch", + code: ` + let result = InterfaceType(3) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = InterfaceType("", 3) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = InterfaceType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + &sema.OptionalType{Type: sema.MetaType}, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckFunctionTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "(String): Int", + code: ` + let result = FunctionType(parameters: [Type()], return: Type()) + `, + expectedError: nil, + }, + { + name: "(String, Int): Bool", + code: ` + let result = FunctionType(parameters: [Type(), Type()], return: Type()) + `, + expectedError: nil, + }, + { + name: "type mismatch first arg", + code: ` + let result = FunctionType(parameters: Type(), return: Type()) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch nested first arg", + code: ` + let result = FunctionType(parameters: [Type(), 3], return: Type()) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch second arg", + code: ` + let result = FunctionType(parameters: [Type(), Type()], return: "") + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = FunctionType(parameters: [Type(), Type()], return: Type(), 4) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "one arg", + code: ` + let result = FunctionType(parameters: [Type(), Type()]) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = FunctionType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "first label missing", + code: ` + let result = FunctionType([Type()], return: Type()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + { + name: "second label missing", + code: ` + let result = FunctionType(parameters: [Type()], Type()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + sema.MetaType, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckReferenceTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "auth &R", + code: ` + resource R {} + let result = ReferenceType(authorized: true, type: Type<@R>()) + `, + expectedError: nil, + }, + { + name: "&String", + code: ` + let result = ReferenceType(authorized: false, type: Type()) + `, + expectedError: nil, + }, + { + name: "type mismatch first arg", + code: ` + let result = ReferenceType(authorized: "", type: Type()) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch second arg", + code: ` + let result = ReferenceType(authorized: true, type: "") + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = ReferenceType(authorized: true, type: Type(), Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "one arg", + code: ` + let result = ReferenceType(authorized: true) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = ReferenceType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "first label missing", + code: ` + resource R {} + let result = ReferenceType(true, type: Type<@R>()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + { + name: "second label missing", + code: ` + resource R {} + let result = ReferenceType(authorized: true, Type<@R>()) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + sema.MetaType, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckRestrictedTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "S{I1, I2}", + code: ` + let result = RestrictedType(identifier: "S", restrictions: ["I1", "I2"]) + `, + expectedError: nil, + }, + { + name: "S{}", + code: ` + struct S {} + let result = RestrictedType(identifier: "S", restrictions: []) + `, + expectedError: nil, + }, + { + name: "{S}", + code: ` + struct S {} + let result = RestrictedType(identifier: nil, restrictions: ["S"]) + `, + expectedError: nil, + }, + { + name: "type mismatch first arg", + code: ` + let result = RestrictedType(identifier: 3, restrictions: ["I"]) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "type mismatch second arg", + code: ` + let result = RestrictedType(identifier: "A", restrictions: [3]) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = RestrictedType(identifier: "A", restrictions: ["I1"], []) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "one arg", + code: ` + let result = RestrictedType(identifier: "A") + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "no args", + code: ` + let result = RestrictedType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "missing first label", + code: ` + let result = RestrictedType("S", restrictions: ["I1", "I2"]) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + { + name: "missing second label", + code: ` + let result = RestrictedType(identifier: "S", ["I1", "I2"]) + `, + expectedError: &sema.MissingArgumentLabelError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + &sema.OptionalType{Type: sema.MetaType}, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} + +func TestCheckCapabilityTypeConstructor(t *testing.T) { + + t.Parallel() + + cases := []struct { + name string + code string + expectedError error + }{ + { + name: "&String", + code: ` + let result = CapabilityType(Type<&String>()) + `, + expectedError: nil, + }, + { + name: "&Int", + code: ` + let result = CapabilityType(Type<&Int>()) + `, + expectedError: nil, + }, + { + name: "resource", + code: ` + resource R {} + let result = CapabilityType(Type<@R>()) + `, + expectedError: nil, + }, + { + name: "type mismatch", + code: ` + let result = CapabilityType(3) + `, + expectedError: &sema.TypeMismatchError{}, + }, + { + name: "too many args", + code: ` + let result = CapabilityType(Type(), Type()) + `, + expectedError: &sema.ArgumentCountError{}, + }, + { + name: "too few args", + code: ` + let result = CapabilityType() + `, + expectedError: &sema.ArgumentCountError{}, + }, + } + + for _, testCase := range cases { + t.Run(testCase.name, func(t *testing.T) { + checker, err := ParseAndCheck(t, testCase.code) + + if testCase.expectedError == nil { + require.NoError(t, err) + assert.Equal(t, + &sema.OptionalType{Type: sema.MetaType}, + RequireGlobalValue(t, checker.Elaboration, "result"), + ) + } else { + errs := ExpectCheckerErrors(t, err, 1) + assert.IsType(t, testCase.expectedError, errs[0]) + } + }) + } +} diff --git a/runtime/tests/interpreter/runtimetype_test.go b/runtime/tests/interpreter/runtimetype_test.go new file mode 100644 index 0000000000..9d4e730717 --- /dev/null +++ b/runtime/tests/interpreter/runtimetype_test.go @@ -0,0 +1,752 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright 2019-2021 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter_test + +import ( + "testing" + + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/onflow/cadence/runtime/tests/utils" + + "github.com/stretchr/testify/assert" +) + +func TestInterpretOptionalType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = OptionalType(Type()) + let b = OptionalType(Type()) + + resource R {} + let c = OptionalType(Type<@R>()) + let d = OptionalType(a) + + let e = Type() + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.OptionalStaticType{ + Type: interpreter.CompositeStaticType{ + Location: utils.TestLocation, + QualifiedIdentifier: "R", + TypeID: "S.test.R", + }, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.OptionalStaticType{ + Type: interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + }, + }, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) +} + +func TestInterpretVariableSizedArrayType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = VariableSizedArrayType(Type()) + let b = VariableSizedArrayType(Type()) + + resource R {} + let c = VariableSizedArrayType(Type<@R>()) + let d = VariableSizedArrayType(a) + + let e = Type<[String]>() + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.VariableSizedStaticType{ + Type: interpreter.CompositeStaticType{ + Location: utils.TestLocation, + QualifiedIdentifier: "R", + TypeID: "S.test.R", + }, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.VariableSizedStaticType{ + Type: interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + }, + }, + inter.Globals["d"].GetValue(), + ) + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) +} + +func TestInterpretConstantSizedArrayType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = ConstantSizedArrayType(type: Type(), size: 10) + let b = ConstantSizedArrayType(type: Type(), size: 5) + + resource R {} + let c = ConstantSizedArrayType(type: Type<@R>(), size: 400) + let d = ConstantSizedArrayType(type: a, size: 6) + + let e = Type<[String; 10]>() + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + Size: int64(10), + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + Size: int64(5), + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ConstantSizedStaticType{ + Type: interpreter.CompositeStaticType{ + Location: utils.TestLocation, + QualifiedIdentifier: "R", + TypeID: "S.test.R", + }, + Size: int64(400), + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ConstantSizedStaticType{ + Type: interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + Size: int64(10), + }, + Size: int64(6), + }, + }, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) +} + +func TestInterpretDictionaryType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = DictionaryType(key: Type(), value: Type())! + let b = DictionaryType(key: Type(), value: Type())! + + resource R {} + let c = DictionaryType(key: Type(), value: Type<@R>())! + let d = DictionaryType(key: Type(), value: a)! + + let e = Type<{String: Int}>()! + + let f = DictionaryType(key: Type<[Bool]>(), value: Type()) + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeString, + ValueType: interpreter.PrimitiveStaticTypeInt, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeInt, + ValueType: interpreter.PrimitiveStaticTypeString, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.DictionaryStaticType{ + ValueType: interpreter.CompositeStaticType{ + Location: utils.TestLocation, + QualifiedIdentifier: "R", + TypeID: "S.test.R", + }, + KeyType: interpreter.PrimitiveStaticTypeInt, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.DictionaryStaticType{ + ValueType: interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeString, + ValueType: interpreter.PrimitiveStaticTypeInt, + }, + KeyType: interpreter.PrimitiveStaticTypeBool, + }, + }, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["f"].GetValue(), + ) +} + +func TestInterpretCompositeType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + struct S {} + struct interface B {} + + let a = CompositeType("S.test.R")! + let b = CompositeType("S.test.S")! + let c = CompositeType("S.test.A") + let d = CompositeType("S.test.B") + + let e = Type<@R>() + + enum F: UInt8 {} + let f = CompositeType("S.test.F")! + let g = CompositeType("PublicKey")! + let h = CompositeType("HashAlgorithm")! + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "R", + Location: utils.TestLocation, + TypeID: "S.test.R", + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "S", + Location: utils.TestLocation, + TypeID: "S.test.S", + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "F", + Location: utils.TestLocation, + TypeID: "S.test.F", + }, + }, + inter.Globals["f"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "PublicKey", + Location: nil, + TypeID: "PublicKey", + }, + }, + inter.Globals["g"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "HashAlgorithm", + Location: nil, + TypeID: "HashAlgorithm", + }, + }, + inter.Globals["h"].GetValue(), + ) +} + +func TestInterpretInterfaceType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource interface R {} + struct interface S {} + struct B {} + + let a = InterfaceType("S.test.R")! + let b = InterfaceType("S.test.S")! + let c = InterfaceType("S.test.A") + let d = InterfaceType("S.test.B") + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.InterfaceStaticType{ + QualifiedIdentifier: "R", + Location: utils.TestLocation, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.InterfaceStaticType{ + QualifiedIdentifier: "S", + Location: utils.TestLocation, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["d"].GetValue(), + ) +} + +func TestInterpretFunctionType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = FunctionType(parameters: [Type()], return: Type()) + let b = FunctionType(parameters: [Type(), Type()], return: Type()) + let c = FunctionType(parameters: [], return: Type()) + + let d = Type<((String): Int)>(); + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.FunctionStaticType{ + Type: &sema.FunctionType{ + Parameters: []*sema.Parameter{{TypeAnnotation: &sema.TypeAnnotation{Type: sema.StringType}}}, + ReturnTypeAnnotation: &sema.TypeAnnotation{Type: sema.IntType}, + }, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.FunctionStaticType{ + Type: &sema.FunctionType{ + Parameters: []*sema.Parameter{ + {TypeAnnotation: &sema.TypeAnnotation{Type: sema.StringType}}, + {TypeAnnotation: &sema.TypeAnnotation{Type: sema.IntType}}, + }, + ReturnTypeAnnotation: &sema.TypeAnnotation{Type: sema.BoolType}, + }, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.FunctionStaticType{ + Type: &sema.FunctionType{ + Parameters: []*sema.Parameter{}, + ReturnTypeAnnotation: &sema.TypeAnnotation{Type: sema.StringType}, + }, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["d"].GetValue(), + ) +} + +func TestInterpretReferenceType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + struct S {} + + let a = ReferenceType(authorized: true, type: Type<@R>()) + let b = ReferenceType(authorized: false, type: Type()) + let c = ReferenceType(authorized: true, type: Type()) + let d = Type() + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ReferenceStaticType{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "R", + Location: utils.TestLocation, + TypeID: "S.test.R", + }, + Authorized: true, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ReferenceStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + Authorized: false, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.ReferenceStaticType{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "S", + Location: utils.TestLocation, + TypeID: "S.test.S", + }, + Authorized: true, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["d"].GetValue(), + ) +} + +func TestInterpretRestrictedType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource interface R {} + struct interface S {} + resource A : R {} + struct B : S {} + + struct interface S2 { + pub let foo : Int + } + + let a = RestrictedType(identifier: "S.test.A", restrictions: ["S.test.R"])! + let b = RestrictedType(identifier: "S.test.B", restrictions: ["S.test.S"])! + + let c = RestrictedType(identifier: "S.test.B", restrictions: ["S.test.R"]) + let d = RestrictedType(identifier: "S.test.A", restrictions: ["S.test.S"]) + let e = RestrictedType(identifier: "S.test.B", restrictions: ["S.test.S2"]) + + let f = RestrictedType(identifier: "S.test.B", restrictions: ["X"]) + let g = RestrictedType(identifier: "S.test.N", restrictions: ["S.test.S2"]) + + let h = Type<@A{R}>() + let i = Type() + + let j = RestrictedType(identifier: nil, restrictions: ["S.test.R"])! + let k = RestrictedType(identifier: nil, restrictions: ["S.test.S"])! + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: &interpreter.RestrictedStaticType{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "A", + Location: utils.TestLocation, + TypeID: "S.test.A", + }, + Restrictions: []interpreter.InterfaceStaticType{ + { + QualifiedIdentifier: "R", + Location: utils.TestLocation, + }, + }, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: &interpreter.RestrictedStaticType{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "B", + Location: utils.TestLocation, + TypeID: "S.test.B", + }, + Restrictions: []interpreter.InterfaceStaticType{ + { + QualifiedIdentifier: "S", + Location: utils.TestLocation, + }, + }, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: &interpreter.RestrictedStaticType{ + Type: interpreter.PrimitiveStaticTypeAnyResource, + Restrictions: []interpreter.InterfaceStaticType{ + { + QualifiedIdentifier: "R", + Location: utils.TestLocation, + }, + }, + }, + }, + inter.Globals["j"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: &interpreter.RestrictedStaticType{ + Type: interpreter.PrimitiveStaticTypeAnyStruct, + Restrictions: []interpreter.InterfaceStaticType{ + { + QualifiedIdentifier: "S", + Location: utils.TestLocation, + }, + }, + }, + }, + inter.Globals["k"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["e"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["f"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["g"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["h"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["b"].GetValue(), + inter.Globals["i"].GetValue(), + ) +} + +func TestInterpretCapabilityType(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let a = CapabilityType(Type<&String>())! + let b = CapabilityType(Type<&Int>())! + + resource R {} + let c = CapabilityType(Type<&R>())! + let d = CapabilityType(Type()) + + let e = Type>() + `) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CapabilityStaticType{ + BorrowType: interpreter.ReferenceStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + Authorized: false, + }, + }, + }, + inter.Globals["a"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CapabilityStaticType{ + BorrowType: interpreter.ReferenceStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + Authorized: false, + }, + }, + }, + inter.Globals["b"].GetValue(), + ) + + assert.Equal(t, + interpreter.TypeValue{ + Type: interpreter.CapabilityStaticType{ + BorrowType: interpreter.ReferenceStaticType{ + Type: interpreter.CompositeStaticType{ + QualifiedIdentifier: "R", + Location: utils.TestLocation, + TypeID: "S.test.R", + }, + Authorized: false, + }, + }, + }, + inter.Globals["c"].GetValue(), + ) + + assert.Equal(t, + interpreter.NilValue{}, + inter.Globals["d"].GetValue(), + ) + + assert.Equal(t, + inter.Globals["a"].GetValue(), + inter.Globals["e"].GetValue(), + ) +}