From 4825f92c04d0135f5230db833074e88d7bf577a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 14:35:41 -0800 Subject: [PATCH] improve type inference: use bound type parameters for argument expected type --- runtime/sema/check_invocation_expression.go | 42 ++++- runtime/sema/checker.go | 3 +- runtime/sema/simple_type.go | 8 +- runtime/sema/type.go | 178 +++++++++++++++--- runtime/stdlib/test_test.go | 18 +- runtime/tests/checker/account_test.go | 20 +- .../tests/checker/builtinfunctions_test.go | 26 ++- runtime/tests/checker/genericfunction_test.go | 10 +- runtime/tests/checker/range_value_test.go | 8 +- runtime/tests/checker/type_inference_test.go | 156 ++++++++++++--- 10 files changed, 356 insertions(+), 113 deletions(-) diff --git a/runtime/sema/check_invocation_expression.go b/runtime/sema/check_invocation_expression.go index db4d5c64c9..df25379fad 100644 --- a/runtime/sema/check_invocation_expression.go +++ b/runtime/sema/check_invocation_expression.go @@ -578,21 +578,46 @@ func (checker *Checker) checkInvocationRequiredArgument( var argumentType Type - if len(functionType.TypeParameters) == 0 { - // If the function doesn't use generic types, then the - // param types can be used to infer the types for arguments. + typeParameterCount := len(functionType.TypeParameters) + + // If all type parameters have been bound to a type, + // then resolve the parameter type with the type arguments, + // and propose the parameter type as the expected type for the argument. + if typeParameters.Len() == typeParameterCount { + + // Optimization: only resolve if there are type parameters. + // This avoids unnecessary work for non-generic functions. + if typeParameterCount > 0 { + parameterType = parameterType.Resolve(typeParameters) + // If the type parameter could not be resolved, use the invalid type. + if parameterType == nil { + parameterType = InvalidType + } + } + argumentType = checker.VisitExpression(argument.Expression, parameterType) + } else { - // TODO: pass the expected type to support for parameters + // If there are still type parameters that have not been bound to a type, + // then check the argument without an expected type. + // + // We will then have to manually check that the argument type is compatible + // with the parameter type (see below). + argumentType = checker.VisitExpression(argument.Expression, nil) // Try to unify the parameter type with the argument type. // If unification fails, fall back to the parameter type for now. - argumentRange := ast.NewRangeFromPositioned(checker.memoryGauge, argument.Expression) - - if parameterType.Unify(argumentType, typeParameters, checker.report, argumentRange) { + if parameterType.Unify( + argumentType, + typeParameters, + checker.report, + checker.memoryGauge, + argument.Expression, + ) { parameterType = parameterType.Resolve(typeParameters) + // If the type parameter could not be resolved, use the invalid type. if parameterType == nil { parameterType = InvalidType } @@ -600,7 +625,6 @@ func (checker *Checker) checkInvocationRequiredArgument( // Check that the type of the argument matches the type of the parameter. - // TODO: remove this once type inferring support for parameters is added checker.checkInvocationArgumentParameterTypeCompatibility( argument.Expression, argumentType, @@ -695,7 +719,7 @@ func (checker *Checker) checkAndBindGenericTypeParameterTypeArguments( // If the type parameter corresponding to the type argument has a type bound, // then check that the argument is a subtype of the type bound. - err := typeParameter.checkTypeBound(ty, ast.NewRangeFromPositioned(checker.memoryGauge, rawTypeArgument)) + err := typeParameter.checkTypeBound(ty, checker.memoryGauge, rawTypeArgument) checker.report(err) // Bind the type argument to the type parameter diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index 2df2c3daf4..c1331d1044 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -2475,7 +2475,8 @@ func (checker *Checker) convertInstantiationType(t *ast.InstantiationType) Type err := typeParameter.checkTypeBound( typeArgument, - ast.NewRangeFromPositioned(checker.memoryGauge, rawTypeArgument), + checker.memoryGauge, + rawTypeArgument, ) checker.report(err) } diff --git a/runtime/sema/simple_type.go b/runtime/sema/simple_type.go index d102815a9c..9f79f305aa 100644 --- a/runtime/sema/simple_type.go +++ b/runtime/sema/simple_type.go @@ -123,7 +123,13 @@ func (t *SimpleType) RewriteWithIntersectionTypes() (Type, bool) { return t, false } -func (*SimpleType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*SimpleType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 6c540641a7..f59e220c5c 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -191,7 +191,8 @@ type Type interface { other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool // Resolve returns a type that is free of generic types (see `GenericType`), @@ -758,7 +759,8 @@ func (t *OptionalType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherOptional, ok := other.(*OptionalType) @@ -766,7 +768,13 @@ func (t *OptionalType) Unify( return false } - return t.Type.Unify(otherOptional.Type, typeParameters, report, outerRange) + return t.Type.Unify( + otherOptional.Type, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *OptionalType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -968,7 +976,8 @@ func (t *GenericType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { if unifiedType, ok := typeParameters.Get(t.TypeParameter); ok { @@ -983,7 +992,7 @@ func (t *GenericType) Unify( TypeParameter: t.TypeParameter, ExpectedType: unifiedType, ActualType: other, - Range: outerRange, + Range: ast.NewRangeFromPositioned(memoryGauge, outerRange), }, ) } @@ -996,7 +1005,7 @@ func (t *GenericType) Unify( // If the type parameter corresponding to the type argument has a type bound, // then check that the argument's type is a subtype of the type bound. - err := t.TypeParameter.checkTypeBound(other, outerRange) + err := t.TypeParameter.checkTypeBound(other, memoryGauge, outerRange) if err != nil { report(err) } @@ -1282,7 +1291,13 @@ func (t *NumericType) MaxInt() *big.Int { return t.maxInt } -func (*NumericType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*NumericType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } @@ -1491,7 +1506,13 @@ func (t *FixedPointNumericType) Scale() uint { return t.scale } -func (*FixedPointNumericType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*FixedPointNumericType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } @@ -2806,7 +2827,8 @@ func (t *VariableSizedType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherArray, ok := other.(*VariableSizedType) @@ -2814,7 +2836,13 @@ func (t *VariableSizedType) Unify( return false } - return t.Type.Unify(otherArray.Type, typeParameters, report, outerRange) + return t.Type.Unify( + otherArray.Type, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *VariableSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -2986,7 +3014,8 @@ func (t *ConstantSizedType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherArray, ok := other.(*ConstantSizedType) @@ -2998,7 +3027,13 @@ func (t *ConstantSizedType) Unify( return false } - return t.Type.Unify(otherArray.Type, typeParameters, report, outerRange) + return t.Type.Unify( + otherArray.Type, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *ConstantSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -3132,7 +3167,7 @@ func (p TypeParameter) Equal(other *TypeParameter) bool { return p.Optional == other.Optional } -func (p TypeParameter) checkTypeBound(ty Type, typeRange ast.Range) error { +func (p TypeParameter) checkTypeBound(ty Type, memoryGauge common.MemoryGauge, typeRange ast.HasPosition) error { if p.TypeBound == nil || p.TypeBound.IsInvalidType() || ty.IsInvalidType() { @@ -3144,7 +3179,7 @@ func (p TypeParameter) checkTypeBound(ty Type, typeRange ast.Range) error { return &TypeMismatchError{ ExpectedType: p.TypeBound, ActualType: ty, - Range: typeRange, + Range: ast.NewRangeFromPositioned(memoryGauge, typeRange), } } @@ -3654,7 +3689,8 @@ func (t *FunctionType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) ( result bool, ) { @@ -3684,6 +3720,7 @@ func (t *FunctionType) Unify( otherParameter.TypeAnnotation.Type, typeParameters, report, + memoryGauge, outerRange, ) result = result || parameterUnified @@ -3695,6 +3732,7 @@ func (t *FunctionType) Unify( otherFunction.ReturnTypeAnnotation.Type, typeParameters, report, + memoryGauge, outerRange, ) @@ -4798,7 +4836,13 @@ func (t *CompositeType) RewriteWithIntersectionTypes() (result Type, rewritten b return t, false } -func (*CompositeType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*CompositeType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { // TODO: return false } @@ -5578,7 +5622,13 @@ func (t *InterfaceType) RewriteWithIntersectionTypes() (Type, bool) { } -func (*InterfaceType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*InterfaceType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { // TODO: return false } @@ -6112,7 +6162,8 @@ func (t *DictionaryType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherDictionary, ok := other.(*DictionaryType) @@ -6120,8 +6171,22 @@ func (t *DictionaryType) Unify( return false } - keyUnified := t.KeyType.Unify(otherDictionary.KeyType, typeParameters, report, outerRange) - valueUnified := t.ValueType.Unify(otherDictionary.ValueType, typeParameters, report, outerRange) + keyUnified := t.KeyType.Unify( + otherDictionary.KeyType, + typeParameters, + report, + memoryGauge, + outerRange, + ) + + valueUnified := t.ValueType.Unify( + otherDictionary.ValueType, + typeParameters, + report, + memoryGauge, + outerRange, + ) + return keyUnified || valueUnified } @@ -6440,14 +6505,21 @@ func (t *InclusiveRangeType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherRange, ok := other.(*InclusiveRangeType) if !ok { return false } - return t.MemberType.Unify(otherRange.MemberType, typeParameters, report, outerRange) + return t.MemberType.Unify( + otherRange.MemberType, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *InclusiveRangeType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -6726,14 +6798,21 @@ func (t *ReferenceType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherReference, ok := other.(*ReferenceType) if !ok { return false } - return t.Type.Unify(otherReference.Type, typeParameters, report, outerRange) + return t.Type.Unify( + otherReference.Type, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *ReferenceType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -6848,7 +6927,13 @@ func (*AddressType) IsSuperType() bool { return false } -func (*AddressType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*AddressType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } @@ -7480,7 +7565,13 @@ func (t *TransactionType) initializeMemberResolvers() { }) } -func (*TransactionType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*TransactionType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } @@ -7762,7 +7853,13 @@ func (t *IntersectionType) SupportedEntitlements() (set *EntitlementOrderedSet) return set } -func (*IntersectionType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*IntersectionType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { // TODO: how do we unify the intersection sets? return false } @@ -7953,7 +8050,8 @@ func (t *CapabilityType) Unify( other Type, typeParameters *TypeParameterTypeOrderedMap, report func(err error), - outerRange ast.Range, + memoryGauge common.MemoryGauge, + outerRange ast.HasPosition, ) bool { otherCap, ok := other.(*CapabilityType) if !ok { @@ -7964,7 +8062,13 @@ func (t *CapabilityType) Unify( return false } - return t.BorrowType.Unify(otherCap.BorrowType, typeParameters, report, outerRange) + return t.BorrowType.Unify( + otherCap.BorrowType, + typeParameters, + report, + memoryGauge, + outerRange, + ) } func (t *CapabilityType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type { @@ -8513,7 +8617,13 @@ func (t *EntitlementType) RewriteWithIntersectionTypes() (Type, bool) { return t, false } -func (*EntitlementType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*EntitlementType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } @@ -8664,7 +8774,13 @@ func (t *EntitlementMapType) RewriteWithIntersectionTypes() (Type, bool) { return t, false } -func (*EntitlementMapType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool { +func (*EntitlementMapType) Unify( + _ Type, + _ *TypeParameterTypeOrderedMap, + _ func(err error), + _ common.MemoryGauge, + _ ast.HasPosition, +) bool { return false } diff --git a/runtime/stdlib/test_test.go b/runtime/stdlib/test_test.go index c390d0616d..ae6a81585b 100644 --- a/runtime/stdlib/test_test.go +++ b/runtime/stdlib/test_test.go @@ -334,9 +334,9 @@ func TestTestNewMatcher(t *testing.T) { _, err := newTestContractInterpreter(t, script) - errs := checker.RequireCheckerErrors(t, err, 2) - assert.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + errs := checker.RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("combined matcher mismatching types", func(t *testing.T) { @@ -499,9 +499,9 @@ func TestTestEqualMatcher(t *testing.T) { _, err := newTestContractInterpreter(t, script) - errs := checker.RequireCheckerErrors(t, err, 2) - assert.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + errs := checker.RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("matcher or", func(t *testing.T) { @@ -1904,9 +1904,9 @@ func TestTestExpect(t *testing.T) { _, err := newTestContractInterpreter(t, script) - errs := checker.RequireCheckerErrors(t, err, 2) - assert.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + errs := checker.RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("resource with resource matcher", func(t *testing.T) { diff --git a/runtime/tests/checker/account_test.go b/runtime/tests/checker/account_test.go index f0cf63e6bf..046bc5f6f8 100644 --- a/runtime/tests/checker/account_test.go +++ b/runtime/tests/checker/account_test.go @@ -219,16 +219,14 @@ func TestCheckAccountStorageSave(t *testing.T) { if domain == common.PathDomainStorage { - errs := RequireCheckerErrors(t, err, 2) + errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - require.IsType(t, &sema.TypeMismatchError{}, errs[1]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) } else { - errs := RequireCheckerErrors(t, err, 3) + errs := RequireCheckerErrors(t, err, 2) - require.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) require.IsType(t, &sema.TypeMismatchError{}, errs[1]) - require.IsType(t, &sema.TypeMismatchError{}, errs[2]) } }) @@ -254,16 +252,14 @@ func TestCheckAccountStorageSave(t *testing.T) { if domain == common.PathDomainStorage { - errs := RequireCheckerErrors(t, err, 2) + errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - require.IsType(t, &sema.TypeMismatchError{}, errs[1]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) } else { - errs := RequireCheckerErrors(t, err, 3) + errs := RequireCheckerErrors(t, err, 2) - require.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) require.IsType(t, &sema.TypeMismatchError{}, errs[1]) - require.IsType(t, &sema.TypeMismatchError{}, errs[2]) } }) } diff --git a/runtime/tests/checker/builtinfunctions_test.go b/runtime/tests/checker/builtinfunctions_test.go index 1bff1ba334..eede313982 100644 --- a/runtime/tests/checker/builtinfunctions_test.go +++ b/runtime/tests/checker/builtinfunctions_test.go @@ -292,7 +292,7 @@ func TestCheckRevertibleRandom(t *testing.T) { } } - runCase := func(t *testing.T, ty sema.Type, code string) { + runValidCase := func(t *testing.T, ty sema.Type, code string) { checker, err := ParseAndCheckWithOptions(t, code, newOptions()) @@ -307,7 +307,7 @@ func TestCheckRevertibleRandom(t *testing.T) { t.Parallel() code := fmt.Sprintf("let rand = revertibleRandom<%s>()", ty) - runCase(t, ty, code) + runValidCase(t, ty, code) }) } @@ -316,7 +316,7 @@ func TestCheckRevertibleRandom(t *testing.T) { t.Parallel() code := fmt.Sprintf("let rand = revertibleRandom<%[1]s>(modulo: %[1]s(1))", ty) - runCase(t, ty, code) + runValidCase(t, ty, code) }) } @@ -394,7 +394,6 @@ func TestCheckRevertibleRandom(t *testing.T) { "modulo type mismatch", "let rand = revertibleRandom(modulo: UInt128(1))", []error{ - &sema.TypeParameterTypeMismatchError{}, &sema.TypeMismatchError{}, }, ) @@ -404,18 +403,17 @@ func TestCheckRevertibleRandom(t *testing.T) { "string modulo", `let rand = revertibleRandom(modulo: "abcd")`, []error{ - &sema.TypeParameterTypeMismatchError{}, &sema.TypeMismatchError{}, }, ) - // This is an error since we do not support type inference of function arguments. - runInvalidCase( - t, - "missing type inference", - "let rand = revertibleRandom(modulo: 1)", - []error{ - &sema.TypeParameterTypeMismatchError{}, - }, - ) + t.Run("type parameter used for argument", func(t *testing.T) { + t.Parallel() + + runValidCase( + t, + sema.UInt256Type, + "let rand = revertibleRandom(modulo: 1)", + ) + }) } diff --git a/runtime/tests/checker/genericfunction_test.go b/runtime/tests/checker/genericfunction_test.go index 00ccd2fc5a..22559c9acf 100644 --- a/runtime/tests/checker/genericfunction_test.go +++ b/runtime/tests/checker/genericfunction_test.go @@ -283,10 +283,9 @@ func TestCheckGenericFunctionInvocation(t *testing.T) { }, ) - errs := RequireCheckerErrors(t, err, 2) + errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("valid: one type parameter, one type argument, one parameter, one arguments", func(t *testing.T) { @@ -423,10 +422,9 @@ func TestCheckGenericFunctionInvocation(t *testing.T) { }, ) - errs := RequireCheckerErrors(t, err, 2) + errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("invalid: one type parameter, no type argument, no parameters, no arguments, return type", func(t *testing.T) { diff --git a/runtime/tests/checker/range_value_test.go b/runtime/tests/checker/range_value_test.go index 5bc94600b3..4e64dd56fc 100644 --- a/runtime/tests/checker/range_value_test.go +++ b/runtime/tests/checker/range_value_test.go @@ -330,13 +330,17 @@ func TestCheckInclusiveRangeConstructionInvalid(t *testing.T) { t, typeString, fmt.Sprintf("let r = InclusiveRange(%s(1), %s(2))", typeString, differentTypeString), - []error{&sema.TypeParameterTypeMismatchError{}, &sema.TypeMismatchError{}}, + []error{ + &sema.TypeMismatchError{}, + }, ) runInvalidCase( t, typeString, fmt.Sprintf("let r = InclusiveRange(%s(1), %s(10), step: %s(2))", typeString, typeString, differentTypeString), - []error{&sema.TypeParameterTypeMismatchError{}, &sema.TypeMismatchError{}}, + []error{ + &sema.TypeMismatchError{}, + }, ) // Not enough arguments diff --git a/runtime/tests/checker/type_inference_test.go b/runtime/tests/checker/type_inference_test.go index 6347f8083d..d8b2298cb7 100644 --- a/runtime/tests/checker/type_inference_test.go +++ b/runtime/tests/checker/type_inference_test.go @@ -335,7 +335,7 @@ func TestCheckFunctionArgumentTypeInference(t *testing.T) { require.NoError(t, err) }) - t.Run("with generics", func(t *testing.T) { + t.Run("with generics, void return type", func(t *testing.T) { t.Parallel() @@ -367,46 +367,146 @@ func TestCheckFunctionArgumentTypeInference(t *testing.T) { }, ) - errs := RequireCheckerErrors(t, err, 2) + require.NoError(t, err) + }) - require.IsType(t, &sema.TypeParameterTypeMismatchError{}, errs[0]) - typeParamMismatchErr := errs[0].(*sema.TypeParameterTypeMismatchError) - assert.Equal( - t, - &sema.VariableSizedType{ - Type: sema.Int8Type, - }, - typeParamMismatchErr.ExpectedType, - ) + t.Run("with generics, generic return type", func(t *testing.T) { - assert.Equal( - t, - &sema.VariableSizedType{ - Type: sema.IntType, + t.Parallel() + + typeParameter := &sema.TypeParameter{ + Name: "T", + TypeBound: nil, + } + + _, err := parseAndCheckWithTestValue(t, + ` + let res: [Int8] = test<[Int8]>([1, 2, 3]) + `, + &sema.FunctionType{ + TypeParameters: []*sema.TypeParameter{ + typeParameter, + }, + Parameters: []sema.Parameter{ + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "value", + TypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), }, - typeParamMismatchErr.ActualType, ) - require.IsType(t, &sema.TypeMismatchError{}, errs[1]) - typeMismatchErr := errs[1].(*sema.TypeMismatchError) + require.NoError(t, err) + }) + + t.Run("with generics, argument type propagation, simple", func(t *testing.T) { - assert.Equal( - t, - &sema.VariableSizedType{ - Type: sema.Int8Type, + t.Parallel() + + typeParameter := &sema.TypeParameter{ + Name: "T", + TypeBound: nil, + } + + _, err := parseAndCheckWithTestValue(t, + ` + let res: UInt8 = test(1 as UInt8, 2) + `, + &sema.FunctionType{ + TypeParameters: []*sema.TypeParameter{ + typeParameter, + }, + Parameters: []sema.Parameter{ + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "a", + TypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), + }, + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "b", + TypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), }, - typeMismatchErr.ExpectedType, ) - assert.Equal( - t, - &sema.VariableSizedType{ - Type: sema.IntType, + require.NoError(t, err) + }) + + t.Run("with generics, argument type propagation, nested", func(t *testing.T) { + + t.Parallel() + + typeParameter := &sema.TypeParameter{ + Name: "T", + TypeBound: nil, + } + + _, err := parseAndCheckWithTestValue(t, + ` + let res: UInt8 = test(1 as UInt8, [2]) + `, + &sema.FunctionType{ + TypeParameters: []*sema.TypeParameter{ + typeParameter, + }, + Parameters: []sema.Parameter{ + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "a", + TypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), + }, + { + Label: sema.ArgumentLabelNotRequired, + Identifier: "b", + TypeAnnotation: sema.NewTypeAnnotation( + &sema.VariableSizedType{ + Type: &sema.GenericType{ + TypeParameter: typeParameter, + }, + }, + ), + }, + }, + ReturnTypeAnnotation: sema.NewTypeAnnotation( + &sema.GenericType{ + TypeParameter: typeParameter, + }, + ), }, - typeMismatchErr.ActualType, ) + require.NoError(t, err) }) + } func TestCheckBinaryExpressionTypeInference(t *testing.T) {