From 495a4796ff224f70fcd7408a7818d9f9e627b827 Mon Sep 17 00:00:00 2001 From: Alex Vitkov <44268717+alexvitkov@users.noreply.github.com> Date: Fri, 18 Aug 2023 18:48:07 +0300 Subject: [PATCH] feat: add syntax for specifying function type environments (#2357) --- .../closure_explicit_types/Nargo.toml | 7 +++ .../closure_explicit_types/src/main.nr | 60 +++++++++++++++++++ crates/noirc_frontend/src/ast/mod.rs | 20 ++++++- .../src/hir/resolution/resolver.rs | 4 +- .../noirc_frontend/src/hir/type_check/expr.rs | 12 ++-- .../noirc_frontend/src/hir/type_check/mod.rs | 45 +++++++------- crates/noirc_frontend/src/hir_def/types.rs | 29 --------- .../src/monomorphization/mod.rs | 37 ++++++++---- crates/noirc_frontend/src/parser/parser.rs | 24 +++++++- 9 files changed, 158 insertions(+), 80 deletions(-) create mode 100644 crates/nargo_cli/tests/execution_success/closure_explicit_types/Nargo.toml create mode 100644 crates/nargo_cli/tests/execution_success/closure_explicit_types/src/main.nr diff --git a/crates/nargo_cli/tests/execution_success/closure_explicit_types/Nargo.toml b/crates/nargo_cli/tests/execution_success/closure_explicit_types/Nargo.toml new file mode 100644 index 00000000000..0ff85ab80bb --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/closure_explicit_types/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "closure_explicit_types" +type = "bin" +authors = [""] +compiler_version = "0.10.3" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/execution_success/closure_explicit_types/src/main.nr b/crates/nargo_cli/tests/execution_success/closure_explicit_types/src/main.nr new file mode 100644 index 00000000000..1db36dcdd77 --- /dev/null +++ b/crates/nargo_cli/tests/execution_success/closure_explicit_types/src/main.nr @@ -0,0 +1,60 @@ + +fn ret_normal_lambda1() -> fn() -> Field { + || 10 +} + +// explicitly specified empty capture group +fn ret_normal_lambda2() -> fn[]() -> Field { + || 20 +} + +// return lamda that captures a thing +fn ret_closure1() -> fn[Field]() -> Field { + let x = 20; + || x + 10 +} + +// return lamda that captures two things +fn ret_closure2() -> fn[Field,Field]() -> Field { + let x = 20; + let y = 10; + || x + y + 10 +} + +// return lamda that captures two things with different types +fn ret_closure3() -> fn[u32,u64]() -> u64 { + let x: u32 = 20; + let y: u64 = 10; + || x as u64 + y + 10 +} + +// accepts closure that has 1 thing in its env, calls it and returns the result +fn accepts_closure1(f: fn[Field]() -> Field) -> Field { + f() +} + +// accepts closure that has 1 thing in its env and returns it +fn accepts_closure2(f: fn[Field]() -> Field) -> fn[Field]() -> Field { + f +} + +// accepts closure with different types in the capture group +fn accepts_closure3(f: fn[u32, u64]() -> u64) -> u64 { + f() +} + +fn main() { + assert(ret_normal_lambda1()() == 10); + assert(ret_normal_lambda2()() == 20); + assert(ret_closure1()() == 30); + assert(ret_closure2()() == 40); + assert(ret_closure3()() == 40); + + let x = 50; + assert(accepts_closure1(|| x) == 50); + assert(accepts_closure2(|| x + 10)() == 60); + + let y: u32 = 30; + let z: u64 = 40; + assert(accepts_closure3(|| y as u64 + z) == 70); +} \ No newline at end of file diff --git a/crates/noirc_frontend/src/ast/mod.rs b/crates/noirc_frontend/src/ast/mod.rs index 1934c3f790c..5bb3eea9db1 100644 --- a/crates/noirc_frontend/src/ast/mod.rs +++ b/crates/noirc_frontend/src/ast/mod.rs @@ -50,7 +50,11 @@ pub enum UnresolvedType { // Note: Tuples have no visibility, instead each of their elements may have one. Tuple(Vec), - Function(/*args:*/ Vec, /*ret:*/ Box), + Function( + /*args:*/ Vec, + /*ret:*/ Box, + /*env:*/ Box, + ), Unspecified, // This is for when the user declares a variable without specifying it's type Error, @@ -109,9 +113,19 @@ impl std::fmt::Display for UnresolvedType { Some(len) => write!(f, "str<{len}>"), }, FormatString(len, elements) => write!(f, "fmt<{len}, {elements}"), - Function(args, ret) => { + Function(args, ret, env) => { let args = vecmap(args, ToString::to_string); - write!(f, "fn({}) -> {ret}", args.join(", ")) + + match &**env { + UnresolvedType::Unit => { + write!(f, "fn({}) -> {ret}", args.join(", ")) + } + UnresolvedType::Tuple(env_types) => { + let env_types = vecmap(env_types, ToString::to_string); + write!(f, "fn[{}]({}) -> {ret}", env_types.join(", "), args.join(", ")) + } + _ => unreachable!(), + } } MutableReference(element) => write!(f, "&mut {element}"), Unit => write!(f, "()"), diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 26c99d436cc..ad638684206 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -361,10 +361,10 @@ impl<'a> Resolver<'a> { UnresolvedType::Tuple(fields) => { Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables))) } - UnresolvedType::Function(args, ret) => { + UnresolvedType::Function(args, ret, env) => { let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); - let env = Box::new(Type::Unit); + let env = Box::new(self.resolve_type_inner(*env, new_variables)); Type::Function(args, ret, env) } UnresolvedType::MutableReference(element) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 20ecbaa3108..9f00f4b61da 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -837,13 +837,11 @@ impl<'interner> TypeChecker<'interner> { } for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { - if arg.try_unify_allow_incompat_lambdas(param).is_err() { - self.errors.push(TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: *arg_span, - }); - } + self.unify(arg, param, || TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + }); } fn_ret.clone() diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index 608dacbbcc7..b1ef147d84c 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -63,33 +63,28 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec Result<(), UnificationError> { - use Type::*; - use TypeVariableKind::*; - - match (self, other) { - (TypeVariable(binding, Normal), other) | (other, TypeVariable(binding, Normal)) => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify_allow_incompat_lambdas(other); - } - - other.try_bind_to(binding) - } - (Function(params_a, ret_a, _), Function(params_b, ret_b, _)) => { - if params_a.len() == params_b.len() { - for (a, b) in params_a.iter().zip(params_b.iter()) { - a.try_unify_allow_incompat_lambdas(b)?; - } - - // no check for environments here! - ret_b.try_unify_allow_incompat_lambdas(ret_a) - } else { - Err(UnificationError) - } - } - _ => self.try_unify(other), - } - } - /// Similar to `unify` but if the check fails this will attempt to coerce the /// argument to the target type. When this happens, the given expression is wrapped in /// a new expression to convert its type. E.g. `array` -> `array.as_slice()` diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index 998f3093d49..92873c3268a 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -784,15 +784,27 @@ impl<'interner> Monomorphizer<'interner> { let is_closure = self.is_function_closure(call.func); if is_closure { - let extracted_func: ast::Expression; - let hir_call_func = self.interner.expression(&call.func); - if let HirExpression::Lambda(l) = hir_call_func { - let (setup, closure_variable) = self.lambda_with_setup(l, call.func); - block_expressions.push(setup); - extracted_func = closure_variable; - } else { - extracted_func = *original_func; - } + let local_id = self.next_local_id(); + + // store the function in a temporary variable before calling it + // this is needed for example if call.func is of the form `foo()()` + // without this, we would translate it to `foo().1(foo().0)` + let let_stmt = ast::Expression::Let(ast::Let { + id: local_id, + mutable: false, + name: "tmp".to_string(), + expression: Box::new(*original_func), + }); + block_expressions.push(let_stmt); + + let extracted_func = ast::Expression::Ident(ast::Ident { + location: None, + definition: Definition::Local(local_id), + mutable: false, + name: "tmp".to_string(), + typ: Self::convert_type(&self.interner.id_type(call.func)), + }); + func = Box::new(ast::Expression::ExtractTupleField( Box::new(extracted_func.clone()), 1usize, @@ -1435,7 +1447,7 @@ mod tests { #[test] fn simple_closure_with_no_captured_variables() { let src = r#" - fn main() -> Field { + fn main() -> pub Field { let x = 1; let closure = || x; closure() @@ -1451,7 +1463,10 @@ mod tests { }; closure_variable$l2 }; - closure$l3.1(closure$l3.0) + { + let tmp$4 = closure$l3; + tmp$l4.1(tmp$l4.0) + } } fn lambda$f1(mut env$l1: (Field)) -> Field { env$l1.0 diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index 6824446dbfe..458dfd352f7 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -971,12 +971,30 @@ fn function_type(type_parser: T) -> impl NoirParser where T: NoirParser, { - let args = parenthesized(type_parser.clone().separated_by(just(Token::Comma)).allow_trailing()); + let types = type_parser.clone().separated_by(just(Token::Comma)).allow_trailing(); + let args = parenthesized(types.clone()); + + let env = just(Token::LeftBracket) + .ignore_then(types) + .then_ignore(just(Token::RightBracket)) + .or_not() + .map(|args| match args { + Some(args) => { + if args.is_empty() { + UnresolvedType::Unit + } else { + UnresolvedType::Tuple(args) + } + } + None => UnresolvedType::Unit, + }); + keyword(Keyword::Fn) - .ignore_then(args) + .ignore_then(env) + .then(args) .then_ignore(just(Token::Arrow)) .then(type_parser) - .map(|(args, ret)| UnresolvedType::Function(args, Box::new(ret))) + .map(|((env, args), ret)| UnresolvedType::Function(args, Box::new(ret), Box::new(env))) } fn mutable_reference_type(type_parser: T) -> impl NoirParser