Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add syntax for specifying function type environments #2357

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "closure_explicit_types"
type = "bin"
authors = [""]
compiler_version = "0.10.3"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

fn ret_normal_lambda1() -> fn() -> Field {
|| 10
}

// explicitly specified empty capture group
fn ret_normal_lambda2() -> fn[]() -> Field {
|| 20
}

// return lamda that captures a thing
fn ret_closure1() -> fn[Field]() -> Field {
let x = 20;
|| x + 10
}

// return lamda that captures two things
fn ret_closure2() -> fn[Field,Field]() -> Field {
let x = 20;
let y = 10;
|| x + y + 10
}

// return lamda that captures two things with different types
fn ret_closure3() -> fn[u32,u64]() -> u64 {
let x: u32 = 20;
let y: u64 = 10;
|| x as u64 + y + 10
}

// accepts closure that has 1 thing in its env, calls it and returns the result
fn accepts_closure1(f: fn[Field]() -> Field) -> Field {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
f()
}

// accepts closure that has 1 thing in its env and returns it
fn accepts_closure2(f: fn[Field]() -> Field) -> fn[Field]() -> Field {
f
}

// accepts closure with different types in the capture group
fn accepts_closure3(f: fn[u32, u64]() -> u64) -> u64 {
f()
}

fn main() {
assert(ret_normal_lambda1()() == 10);
assert(ret_normal_lambda2()() == 20);
assert(ret_closure1()() == 30);
assert(ret_closure2()() == 40);
assert(ret_closure3()() == 40);

let x = 50;
assert(accepts_closure1(|| x) == 50);
assert(accepts_closure2(|| x + 10)() == 60);

let y: u32 = 30;
let z: u64 = 40;
assert(accepts_closure3(|| y as u64 + z) == 70);
}
20 changes: 17 additions & 3 deletions crates/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ pub enum UnresolvedType {
// Note: Tuples have no visibility, instead each of their elements may have one.
Tuple(Vec<UnresolvedType>),

Function(/*args:*/ Vec<UnresolvedType>, /*ret:*/ Box<UnresolvedType>),
Function(
/*args:*/ Vec<UnresolvedType>,
/*ret:*/ Box<UnresolvedType>,
/*env:*/ Box<UnresolvedType>,
),

Unspecified, // This is for when the user declares a variable without specifying it's type
Error,
Expand Down Expand Up @@ -109,9 +113,19 @@ impl std::fmt::Display for UnresolvedType {
Some(len) => write!(f, "str<{len}>"),
},
FormatString(len, elements) => write!(f, "fmt<{len}, {elements}"),
Function(args, ret) => {
Function(args, ret, env) => {
let args = vecmap(args, ToString::to_string);
write!(f, "fn({}) -> {ret}", args.join(", "))

match &**env {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
UnresolvedType::Unit => {
write!(f, "fn({}) -> {ret}", args.join(", "))
}
UnresolvedType::Tuple(env_types) => {
let env_types = vecmap(env_types, ToString::to_string);
write!(f, "fn[{}]({}) -> {ret}", env_types.join(", "), args.join(", "))
}
_ => unreachable!(),
}
}
MutableReference(element) => write!(f, "&mut {element}"),
Unit => write!(f, "()"),
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ impl<'a> Resolver<'a> {
UnresolvedType::Tuple(fields) => {
Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables)))
}
UnresolvedType::Function(args, ret) => {
UnresolvedType::Function(args, ret, env) => {
let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables));
let ret = Box::new(self.resolve_type_inner(*ret, new_variables));
let env = Box::new(Type::Unit);
let env = Box::new(self.resolve_type_inner(*env, new_variables));
Type::Function(args, ret, env)
}
UnresolvedType::MutableReference(element) => {
Expand Down
12 changes: 5 additions & 7 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,13 +837,11 @@ impl<'interner> TypeChecker<'interner> {
}

for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) {
if arg.try_unify_allow_incompat_lambdas(param).is_err() {
self.errors.push(TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
expr_span: *arg_span,
});
}
self.unify(arg, param, || TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
expr_span: *arg_span,
});
}

fn_ret.clone()
Expand Down
45 changes: 20 additions & 25 deletions crates/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,33 +63,28 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
let (expr_span, empty_function) = function_info(interner, function_body_id);

let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

let result = function_last_type.try_unify_allow_incompat_lambdas(&declared_return_type);

if result.is_err() {
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
}
if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
}

error
},
);
}
error
},
);
}

errors
Expand Down
29 changes: 0 additions & 29 deletions crates/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,35 +947,6 @@ impl Type {
}
}

/// Similar to try_unify() but allows non-matching capture groups for function types
pub fn try_unify_allow_incompat_lambdas(&self, other: &Type) -> Result<(), UnificationError> {
use Type::*;
use TypeVariableKind::*;

match (self, other) {
(TypeVariable(binding, Normal), other) | (other, TypeVariable(binding, Normal)) => {
if let TypeBinding::Bound(link) = &*binding.borrow() {
return link.try_unify_allow_incompat_lambdas(other);
}

other.try_bind_to(binding)
}
(Function(params_a, ret_a, _), Function(params_b, ret_b, _)) => {
if params_a.len() == params_b.len() {
for (a, b) in params_a.iter().zip(params_b.iter()) {
a.try_unify_allow_incompat_lambdas(b)?;
}

// no check for environments here!
ret_b.try_unify_allow_incompat_lambdas(ret_a)
} else {
Err(UnificationError)
}
}
_ => self.try_unify(other),
}
}

/// Similar to `unify` but if the check fails this will attempt to coerce the
/// argument to the target type. When this happens, the given expression is wrapped in
/// a new expression to convert its type. E.g. `array` -> `array.as_slice()`
Expand Down
37 changes: 26 additions & 11 deletions crates/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,15 +784,27 @@ impl<'interner> Monomorphizer<'interner> {

let is_closure = self.is_function_closure(call.func);
if is_closure {
let extracted_func: ast::Expression;
let hir_call_func = self.interner.expression(&call.func);
if let HirExpression::Lambda(l) = hir_call_func {
let (setup, closure_variable) = self.lambda_with_setup(l, call.func);
block_expressions.push(setup);
extracted_func = closure_variable;
} else {
extracted_func = *original_func;
}
let local_id = self.next_local_id();

// store the function in a temporary variable before calling it
// this is needed for example if call.func is of the form `foo()()`
// without this, we would translate it to `foo().1(foo().0)`
let let_stmt = ast::Expression::Let(ast::Let {
id: local_id,
mutable: false,
name: "tmp".to_string(),
expression: Box::new(*original_func),
});
block_expressions.push(let_stmt);

let extracted_func = ast::Expression::Ident(ast::Ident {
location: None,
definition: Definition::Local(local_id),
mutable: false,
name: "tmp".to_string(),
typ: Self::convert_type(&self.interner.id_type(call.func)),
});

func = Box::new(ast::Expression::ExtractTupleField(
Box::new(extracted_func.clone()),
1usize,
Expand Down Expand Up @@ -1435,7 +1447,7 @@ mod tests {
#[test]
fn simple_closure_with_no_captured_variables() {
let src = r#"
fn main() -> Field {
fn main() -> pub Field {
let x = 1;
let closure = || x;
closure()
Expand All @@ -1451,7 +1463,10 @@ mod tests {
};
closure_variable$l2
};
closure$l3.1(closure$l3.0)
{
let tmp$4 = closure$l3;
tmp$l4.1(tmp$l4.0)
}
}
fn lambda$f1(mut env$l1: (Field)) -> Field {
env$l1.0
Expand Down
24 changes: 21 additions & 3 deletions crates/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,12 +971,30 @@ fn function_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>
where
T: NoirParser<UnresolvedType>,
{
let args = parenthesized(type_parser.clone().separated_by(just(Token::Comma)).allow_trailing());
let types = type_parser.clone().separated_by(just(Token::Comma)).allow_trailing();
let args = parenthesized(types.clone());

let env = just(Token::LeftBracket)
.ignore_then(types)
.then_ignore(just(Token::RightBracket))
.or_not()
.map(|args| match args {
Some(args) => {
if args.is_empty() {
UnresolvedType::Unit
} else {
UnresolvedType::Tuple(args)
}
}
None => UnresolvedType::Unit,
});

keyword(Keyword::Fn)
.ignore_then(args)
.ignore_then(env)
.then(args)
.then_ignore(just(Token::Arrow))
.then(type_parser)
.map(|(args, ret)| UnresolvedType::Function(args, Box::new(ret)))
.map(|((env, args), ret)| UnresolvedType::Function(args, Box::new(ret), Box::new(env)))
}

fn mutable_reference_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>
Expand Down