diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index b31b8f36031..65e94c4fcf4 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -506,7 +506,7 @@ impl<'context> Elaborator<'context> { unseen_fields.remove(&field_name); seen_fields.insert(field_name.clone()); - self.unify_with_coercions(&field_type, expected_type, resolved, || { + self.unify_with_coercions(&field_type, expected_type, resolved, field_span, || { TypeCheckError::TypeMismatch { expected_typ: expected_type.to_string(), expr_typ: field_type.to_string(), diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 48380383eb0..da4492eb211 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -78,7 +78,7 @@ impl<'context> Elaborator<'context> { let r#type = if annotated_type != Type::Error { // Now check if LHS is the same type as the RHS // Importantly, we do not coerce any types implicitly - self.unify_with_coercions(&expr_type, &annotated_type, expression, || { + self.unify_with_coercions(&expr_type, &annotated_type, expression, expr_span, || { TypeCheckError::TypeMismatch { expected_typ: annotated_type.to_string(), expr_typ: expr_type.to_string(), @@ -136,7 +136,7 @@ impl<'context> Elaborator<'context> { self.push_err(TypeCheckError::VariableMustBeMutable { name, span }); } - self.unify_with_coercions(&expr_type, &lvalue_type, expression, || { + self.unify_with_coercions(&expr_type, &lvalue_type, expression, span, || { TypeCheckError::TypeMismatchWithSource { actual: expr_type.clone(), expected: lvalue_type.clone(), diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 58d019c86aa..80c5aad7fd2 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -641,10 +641,18 @@ impl<'context> Elaborator<'context> { actual: &Type, expected: &Type, expression: ExprId, + span: Span, make_error: impl FnOnce() -> TypeCheckError, ) { let mut errors = Vec::new(); - actual.unify_with_coercions(expected, expression, self.interner, &mut errors, make_error); + actual.unify_with_coercions( + expected, + expression, + span, + self.interner, + &mut errors, + make_error, + ); self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); } @@ -736,10 +744,12 @@ impl<'context> Elaborator<'context> { } for (param, (arg, arg_expr_id, arg_span)) in fn_params.iter().zip(callsite_args) { - self.unify_with_coercions(arg, param, *arg_expr_id, || TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: *arg_span, + self.unify_with_coercions(arg, param, *arg_expr_id, *arg_span, || { + TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + } }); } @@ -1449,7 +1459,7 @@ impl<'context> Elaborator<'context> { }); } } else { - self.unify_with_coercions(&body_type, declared_return_type, body_id, || { + self.unify_with_coercions(&body_type, declared_return_type, body_id, func_span, || { let mut error = TypeCheckError::TypeMismatchWithSource { expected: declared_return_type.clone(), actual: body_type.clone(), diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index 380753d8198..d51760f9d20 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -138,6 +138,8 @@ pub enum TypeCheckError { UnconstrainedSliceReturnToConstrained { span: Span }, #[error("Call to unconstrained function is unsafe and must be in an unconstrained function or unsafe block")] Unsafe { span: Span }, + #[error("Converting an unconstrained fn to a non-unconstrained fn is unsafe")] + UnsafeFn { span: Span }, #[error("Slices must have constant length")] NonConstantSliceLength { span: Span }, #[error("Only sized types may be used in the entry point to a program")] @@ -361,6 +363,9 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { TypeCheckError::Unsafe { span } => { Diagnostic::simple_warning(error.to_string(), String::new(), *span) } + TypeCheckError::UnsafeFn { span } => { + Diagnostic::simple_warning(error.to_string(), String::new(), *span) + } } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index d2f499950f1..d6d114c7075 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -205,6 +205,12 @@ impl ResolvedGeneric { } } +enum FunctionCoercionResult { + NoCoercion, + Coerced(Type), + UnconstrainedMismatch(Type), +} + impl std::hash::Hash for StructType { fn hash(&self, state: &mut H) { self.id.hash(state); @@ -1776,6 +1782,7 @@ impl Type { &self, expected: &Type, expression: ExprId, + span: Span, interner: &mut NodeInterner, errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, @@ -1792,17 +1799,24 @@ impl Type { } // Try to coerce `fn (..) -> T` to `unconstrained fn (..) -> T` - if let Some(coerced_self) = self.try_fn_to_unconstrained_fn_coercion(expected) { - coerced_self.unify_with_coercions(expected, expression, interner, errors, make_error); - return; - } + match self.try_fn_to_unconstrained_fn_coercion(expected) { + FunctionCoercionResult::NoCoercion => errors.push(make_error()), + FunctionCoercionResult::Coerced(coerced_self) => { + coerced_self + .unify_with_coercions(expected, expression, span, interner, errors, make_error); + } + FunctionCoercionResult::UnconstrainedMismatch(coerced_self) => { + errors.push(TypeCheckError::UnsafeFn { span }); - errors.push(make_error()); + coerced_self + .unify_with_coercions(expected, expression, span, interner, errors, make_error); + } + } } // If `self` and `expected` are function types, tries to coerce `self` to `expected`. // Returns None if no coercion can be applied, otherwise returns `self` coerced to `expected`. - fn try_fn_to_unconstrained_fn_coercion(&self, expected: &Type) -> Option { + fn try_fn_to_unconstrained_fn_coercion(&self, expected: &Type) -> FunctionCoercionResult { // If `self` and `expected` are function types, `self` can be coerced to `expected` // if `self` is unconstrained and `expected` is not. The other way around is an error, though. if let ( @@ -1810,10 +1824,15 @@ impl Type { Type::Function(_, _, _, unconstrained_expected), ) = (self.follow_bindings(), expected.follow_bindings()) { - (!unconstrained_self && unconstrained_expected) - .then(|| Type::Function(params, ret, env, unconstrained_expected)) + let coerced_type = Type::Function(params, ret, env, unconstrained_expected); + + match (unconstrained_self, unconstrained_expected) { + (true, true) | (false, false) => FunctionCoercionResult::NoCoercion, + (false, true) => FunctionCoercionResult::Coerced(coerced_type), + (true, false) => FunctionCoercionResult::UnconstrainedMismatch(coerced_type), + } } else { - None + FunctionCoercionResult::NoCoercion } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 26a82216eee..bba596ed19f 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2562,16 +2562,8 @@ fn cannot_pass_unconstrained_function_to_regular_function() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - if let CompilationError::TypeError(TypeCheckError::TypeMismatch { - expected_typ, - expr_typ, - .. - }) = &errors[0].0 - { - assert_eq!(expected_typ, "fn() -> ()"); - assert_eq!(expr_typ, "unconstrained fn() -> ()"); - } else { - panic!("Expected a type mismatch error, got {:?}", errors[0].0); + let CompilationError::TypeError(TypeCheckError::UnsafeFn { .. }) = &errors[0].0 else { + panic!("Expected an UnsafeFn error, got {:?}", errors[0].0); }; } @@ -2630,8 +2622,8 @@ fn cannot_pass_unconstrained_function_to_constrained_function() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - let CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) = &errors[0].0 else { - panic!("Expected a type mismatch error, got {:?}", errors[0].0); + let CompilationError::TypeError(TypeCheckError::UnsafeFn { .. }) = &errors[0].0 else { + panic!("Expected an UnsafeFn error, got {:?}", errors[0].0); }; }