From 56b7021de1357176d4489abf5f463641bffd9a4c Mon Sep 17 00:00:00 2001
From: Alexander Ivanov
Date: Tue, 18 Jul 2023 16:39:44 +0300
Subject: [PATCH] feat: Initial work on rewriting closures to regular functions with hidden env

This commit implements the following mechanism:

At the point where a lambda expression is encountered, we initialize a
tuple holding the lambda's captured environment and rewrite the lambda
into a regular function that takes this environment as an additional
parameter. All calls to the closure are then modified to insert this
hidden parameter.

In other words, the following code:

```
let x = some_value;
let closure = |a| x + a;

println(closure(10));
println(closure(20));
```

is rewritten to:

```
fn closure(env: (Field,), a: Field) -> Field {
  env.0 + a
}

let x = some_value;
let closure_env = (x,);

println(closure(closure_env, 10));
println(closure(closure_env, 20));
```

In the presence of nested closures, we propagate the captured variables
implicitly through all intermediate closures:

```
let x = some_value;

let closure = |a, c|
  # here, `x` is initialized from the hidden env of the outer closure
  let inner_closure = |b|
    a + b + x

  inner_closure(c)
```

To make these transforms possible, the following changes were made to
the logic of the HIR resolver and the monomorphization pass:

* In the HIR resolver pass, the code determines the precise list of
  variables captured by each lambda. Along with the list, we compute
  the index of each captured variable within the parent closure's
  environment (when the capture is propagated).

* A new `Closure` type is introduced in order to recognize the
  call sites that need the automatic environment variable treatment.
  It's a bit unfortunate that the Closure type is defined within the
  `AST` modules that are used to describe the output of the
  monomorphization pass, because we aim to eliminate all closures
  during the pass.
A better solution would have been possible if the type check pass
after HIR resolution emitted types specific to the HIR pass (then
the closures would exist only within this separate non-simplified
type system).

* The majority of the work is in the Lambda processing step in the
  monomorphizer, which performs the necessary transformations based
  on the above information.

Remaining things to do:

* There are a number of pending TODO items for various minor
  unresolved loose ends in the code.

* There are a lot of possible additional tests to be written.

* Update docs
---
 crates/noirc_driver/src/lib.rs | 1 -
 crates/noirc_evaluator/src/ssa/context.rs | 1 +
 .../src/hir/resolution/resolver.rs | 175 ++++++++++++++----
 .../noirc_frontend/src/hir/type_check/expr.rs | 107 +++++++----
 .../noirc_frontend/src/hir/type_check/mod.rs | 23 +++
 crates/noirc_frontend/src/hir_def/expr.rs | 16 ++
 crates/noirc_frontend/src/hir_def/types.rs | 17 ++
 .../src/monomorphization/ast.rs | 54 +++++-
 .../src/monomorphization/mod.rs | 133 +++++++++++--
 crates/noirc_frontend/src/node_interner.rs | 1 +
 10 files changed, 446 insertions(+), 82 deletions(-)

diff --git a/crates/noirc_driver/src/lib.rs b/crates/noirc_driver/src/lib.rs index 525c15af1e8..7c68ca24c17 100644 --- a/crates/noirc_driver/src/lib.rs +++ b/crates/noirc_driver/src/lib.rs @@ -322,7 +322,6 @@ pub fn compile_no_check( main_function: FuncId, ) -> Result { let program = monomorphize(main_function, &context.def_interner); - let (circuit, debug, abi) = if options.experimental_ssa { experimental_create_circuit(program, options.show_ssa, options.show_output)?
} else { diff --git a/crates/noirc_evaluator/src/ssa/context.rs b/crates/noirc_evaluator/src/ssa/context.rs index 0fecb633db6..c00227a22ae 100644 --- a/crates/noirc_evaluator/src/ssa/context.rs +++ b/crates/noirc_evaluator/src/ssa/context.rs @@ -727,6 +727,7 @@ impl SsaContext { self.log(enable_logging, "reduce", "\ninlining:"); inline::inline_tree(self, self.first_block, &decision)?; + self.log(enable_logging, "reduce", "\nmerging paths:"); block::merge_path(self, self.first_block, BlockId::dummy(), None)?; //The CFG is now fully flattened, so we keep only the first block. let mut to_remove = Vec::new(); diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index ea0f341e983..99bcfb9259b 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -12,10 +12,10 @@ // // XXX: Resolver does not check for unused functions use crate::hir_def::expr::{ - HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, - HirConstructorExpression, HirExpression, HirForExpression, HirIdent, HirIfExpression, - HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, - HirMethodCallExpression, HirPrefixExpression, + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, + HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, + HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }; use crate::token::Attribute; use std::collections::{HashMap, HashSet}; @@ -57,6 +57,11 @@ type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; +pub struct LambdaContext { + captures: Vec, + scope_index: usize, +} + /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 /// 
definition in scope, and to convert the AST into the HIR. /// @@ -80,12 +85,10 @@ pub struct Resolver<'a> { /// were declared in. generics: Vec<(Rc, TypeVariable, Span)>, - /// Lambdas share the function scope of the function they're defined in, - /// so to identify whether they use any variables from the parent function - /// we keep track of the scope index a variable is declared in. When a lambda - /// is declared we push a scope and set this lambda_index to the scope index. - /// Any variable from a scope less than that must be from the parent function. - lambda_index: usize, + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, } /// ResolverMetas are tagged onto each definition to track how many times they are used @@ -111,7 +114,7 @@ impl<'a> Resolver<'a> { self_type: None, generics: Vec::new(), errors: Vec::new(), - lambda_index: 0, + lambda_stack: Vec::new(), file, } } @@ -124,10 +127,6 @@ impl<'a> Resolver<'a> { self.errors.push(err); } - fn current_lambda_index(&self) -> usize { - self.scopes.current_scope_index() - } - /// Resolving a function involves interning the metadata /// interning any statements inside of the function /// and interning the function itself @@ -276,25 +275,25 @@ impl<'a> Resolver<'a> { // // If a variable is not found, then an error is logged and a dummy id // is returned, for better error reporting UX - fn find_variable_or_default(&mut self, name: &Ident) -> HirIdent { + fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { self.find_variable(name).unwrap_or_else(|error| { self.push_err(error); let id = DefinitionId::dummy_id(); let location = Location::new(name.span(), self.file); - HirIdent { location, id } + (HirIdent { location, id }, 0) }) } - fn find_variable(&mut self, name: &Ident) -> Result { + fn find_variable(&mut self, name: &Ident) 
-> Result<(HirIdent, usize), ResolverError> { // Find the definition for this Ident let scope_tree = self.scopes.current_scope_tree(); let variable = scope_tree.find(&name.0.contents); let location = Location::new(name.span(), self.file); - if let Some((variable_found, _)) = variable { + if let Some((variable_found, scope)) = variable { variable_found.num_times_used += 1; let id = variable_found.ident.id; - Ok(HirIdent { location, id }) + Ok((HirIdent { location, id }, scope)) } else { Err(ResolverError::VariableNotDeclared { name: name.0.contents.clone(), @@ -478,24 +477,24 @@ impl<'a> Resolver<'a> { } } - fn get_ident_from_path(&mut self, path: Path) -> HirIdent { + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { let location = Location::new(path.span(), self.file); let error = match path.as_ident().map(|ident| self.find_variable(ident)) { - Some(Ok(ident)) => return ident, + Some(Ok(found)) => return found, // Try to look it up as a global, but still issue the first error if we fail Some(Err(error)) => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(_) => error, }, None => match self.lookup_global(path) { - Ok(id) => return HirIdent { location, id }, + Ok(id) => return (HirIdent { location, id }, 0), Err(error) => error, }, }; self.push_err(error); let id = DefinitionId::dummy_id(); - HirIdent { location, id } + (HirIdent { location, id }, 0) } /// Translates an UnresolvedType to a Type @@ -774,12 +773,15 @@ impl<'a> Resolver<'a> { Self::find_numeric_generics_in_type(field, found); } } + Type::Function(parameters, return_type) => { for parameter in parameters { Self::find_numeric_generics_in_type(parameter, found); } Self::find_numeric_generics_in_type(return_type, found); } + Type::Closure(func) => Self::find_numeric_generics_in_type(func, found), + Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { if let 
Type::NamedGeneric(type_variable, name) = generic { @@ -841,7 +843,7 @@ impl<'a> Resolver<'a> { fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { match lvalue { LValue::Ident(ident) => { - HirLValue::Ident(self.find_variable_or_default(&ident), Type::Error) + HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error) } LValue::MemberAccess { object, field_name } => { let object = Box::new(self.resolve_lvalue(*object)); @@ -889,7 +891,52 @@ impl<'a> Resolver<'a> { // Otherwise, then it is referring to an Identifier // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; // If the expression is a singular indent, we search the resolver's current scope as normal. - let hir_ident = self.get_ident_from_path(path); + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(_) => {} + DefinitionKind::Global(_) => {} + DefinitionKind::GenericType(_) => {} + // We ignore the above definition kinds because only local variables can be captured by closures. + DefinitionKind::Local(_) => { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let pos = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if pos.is_none() { + self.lambda_stack[lambda_index].captures.push( + HirCapturedVar { + ident: hir_ident, + transitive_capture_index, + }, + ); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. 
+ // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(pos.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )) + } + } + } + } + } + } + HirExpression::Ident(hir_ident) } ExpressionKind::Prefix(prefix) => { @@ -1010,8 +1057,10 @@ impl<'a> Resolver<'a> { // We must stay in the same function scope as the parent function to allow for closures // to capture variables. This is currently limited to immutable variables. ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| { - let new_index = this.current_lambda_index(); - let old_index = std::mem::replace(&mut this.lambda_index, new_index); + let scope_index = this.scopes.current_scope_index(); + + this.lambda_stack + .push(LambdaContext { captures: Vec::new(), scope_index: scope_index }); let parameters = vecmap(lambda.parameters, |(pattern, typ)| { let parameter = DefinitionKind::Local(None); @@ -1021,8 +1070,14 @@ impl<'a> Resolver<'a> { let return_type = this.resolve_inferred_type(lambda.return_type); let body = this.resolve_expression(lambda.body); - this.lambda_index = old_index; - HirExpression::Lambda(HirLambda { parameters, return_type, body }) + let lambda_context = this.lambda_stack.pop().unwrap(); + + HirExpression::Lambda(HirLambda { + parameters, + return_type, + body, + captures: lambda_context.captures, + }) }), }; @@ -1299,6 +1354,7 @@ pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result< #[cfg(test)] mod test { + use core::panic; use std::collections::HashMap; use fm::FileId; @@ -1322,7 +1378,9 @@ mod test { // and functions can be forward declared fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec { let (program, errors) = parse_program(src); - assert!(errors.is_empty()); + if !errors.is_empty() { + panic!("Unexpected parse errors in test code: {:?}", errors); + } let mut interner = 
NodeInterner::default(); @@ -1530,6 +1588,59 @@ mod test { assert!(errors.is_empty()); } + #[test] + fn resolve_basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + } + + #[test] + fn resolve_complex_closures() { + let src = r#" + fn main(x: Field) -> pub Field { + let closure_without_captures = |x| x + x; + let a = closure_without_captures(1); + + let closure_capturing_a_param = |y| y + x; + let b = closure_capturing_a_param(2); + + let closure_capturing_a_local_var = |y| y + b; + let c = closure_capturing_a_local_var(3); + + let closure_with_transitive_captures = |y| { + let d = 5; + let nested_closure = |z| { + let doubly_nested_closure = |w| w + x + b; + a + z + y + d + x + doubly_nested_closure(4) + x + y + }; + let res = nested_closure(5); + res + }; + + a + b + c + closure_with_transitive_captures(6) + } + "#; + + let errors = resolve_src_code(src, vec!["main", "foo"]); + if !errors.is_empty() { + println!("Unexpected errors: {:?}", errors); + assert!(false); // there should be no errors + } + + // TODO: Create a more sophisticated set of search functions over the HIR, so we can check + // that the correct variables are captured in each closure + } + fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { diff --git a/crates/noirc_frontend/src/hir/type_check/expr.rs b/crates/noirc_frontend/src/hir/type_check/expr.rs index 2ea9a33d191..91690dab226 100644 --- a/crates/noirc_frontend/src/hir/type_check/expr.rs +++ b/crates/noirc_frontend/src/hir/type_check/expr.rs @@ -241,11 +241,19 @@ impl<'interner> TypeChecker<'interner> { Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) } 
HirExpression::Lambda(lambda) => { - let params = vecmap(lambda.parameters, |(pattern, typ)| { - self.bind_pattern(&pattern, typ.clone()); + let captured_vars = vecmap(lambda.captures, |capture| { + let typ = self.interner.id_type(capture.ident.id); typ }); + let env_type = Type::Tuple(captured_vars); + let mut params = vec![env_type]; + + for (pattern, typ) in lambda.parameters { + self.bind_pattern(&pattern, typ.clone()); + params.push(typ); + } + let actual_return = self.check_expression(&lambda.body); let span = self.interner.expr_span(&lambda.body); @@ -256,7 +264,9 @@ impl<'interner> TypeChecker<'interner> { expr_span: span, } }); - Type::Function(params, Box::new(lambda.return_type)) + + let function_type = Type::Function(params, Box::new(lambda.return_type)); + Type::Closure(Box::new(function_type)) } }; @@ -761,6 +771,46 @@ impl<'interner> TypeChecker<'interner> { } } + fn bind_function_type_impl( + &mut self, + fn_params: &Vec, + fn_ret: &Type, + callsite_args: &Vec<(Type, Span)>, + span: Span, + skip_params: usize, + ) -> Type { + let real_fn_params_count = fn_params.len() - skip_params; + + if real_fn_params_count != callsite_args.len() { + let empty_or_s = if real_fn_params_count == 1 { "" } else { "s" }; + let was_or_were = if callsite_args.len() == 1 { "was" } else { "were" }; + + self.errors.push(TypeCheckError::Unstructured { + msg: format!( + "Function expects {} parameter{} but {} {} given", + real_fn_params_count, + empty_or_s, + callsite_args.len(), + was_or_were + ), + span, + }); + return Type::Error; + } + + for (param, (arg, arg_span)) in fn_params.iter().skip(skip_params).zip(callsite_args) { + arg.make_subtype_of(param, *arg_span, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + } + }); + } + + fn_ret.clone() + } + fn bind_function_type(&mut self, function: Type, args: Vec<(Type, Span)>, span: Span) -> Type { // Could do a single unification 
for the entire function type, but matching beforehand // lets us issue a more precise error on the individual argument that fails to type check. @@ -777,36 +827,31 @@ impl<'interner> TypeChecker<'interner> { ret } - Type::Function(parameters, ret) => { - if parameters.len() != args.len() { - let empty_or_s = if parameters.len() == 1 { "" } else { "s" }; - let was_or_were = if args.len() == 1 { "was" } else { "were" }; - - self.errors.push(TypeCheckError::Unstructured { - msg: format!( - "Function expects {} parameter{} but {} {} given", - parameters.len(), - empty_or_s, - args.len(), - was_or_were - ), + Type::Closure(closure) => match closure.as_ref() { + Type::Function(parameters, ret) => { + if parameters.len() < 1 { + unreachable!("Closure type should always contain the captured variables tuple type: {}", closure.as_ref()); + } + self.bind_function_type_impl( + parameters.as_ref(), + ret.as_ref(), + args.as_ref(), span, - }); - return Type::Error; + 1, + ) } - - for (param, (arg, arg_span)) in parameters.iter().zip(args) { - arg.make_subtype_of(param, arg_span, &mut self.errors, || { - TypeCheckError::TypeMismatch { - expected_typ: param.to_string(), - expr_typ: arg.to_string(), - expr_span: arg_span, - } - }); - } - - *ret - } + _ => unreachable!( + "Closure type should always contain a function pointer type: {}", + closure.as_ref() + ), + }, + Type::Function(parameters, ret) => self.bind_function_type_impl( + parameters.as_ref(), + ret.as_ref(), + args.as_ref(), + span, + 0, + ), Type::Error => Type::Error, other => { self.errors.push(TypeCheckError::Unstructured { diff --git a/crates/noirc_frontend/src/hir/type_check/mod.rs b/crates/noirc_frontend/src/hir/type_check/mod.rs index b8bb6c788e9..580e0c7b8ba 100644 --- a/crates/noirc_frontend/src/hir/type_check/mod.rs +++ b/crates/noirc_frontend/src/hir/type_check/mod.rs @@ -142,6 +142,7 @@ impl<'interner> TypeChecker<'interner> { #[cfg(test)] mod test { use std::collections::HashMap; + use std::vec; use 
fm::FileId; use iter_extended::vecmap; @@ -303,7 +304,29 @@ mod test { type_check_src_code(src, vec![String::from("main"), String::from("foo")]); } + #[test] + fn basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + + type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + } + #[test] + fn closure_with_no_args() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = || x; + closure() + } + "#; + + type_check_src_code(src, vec![String::from("main")]); + } // This is the same Stub that is in the resolver, maybe we can pull this out into a test module and re-use? struct TestPathResolver(HashMap); diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index b9ee6634cc7..c9d09b43d69 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -191,9 +191,25 @@ impl HirBlockExpression { } } +/// A variable captured inside a closure +#[derive(Debug, Clone)] +pub struct HirCapturedVar { + pub ident: HirIdent, + + /// This will be None when the capture refers to a local variable declared + /// in the same scope as the closure. In closure-inside-another-closure + /// scenarios, we might have transitive captures of variables that must + /// be propagated during the construction of each closure. In this case, + /// we store the index of the captured variable in the environment of our + /// direct parent closure. We do this in order to simplify the HIR to AST + /// transformation in the monomorphization pass. 
+ pub transitive_capture_index: Option, +} + #[derive(Debug, Clone)] pub struct HirLambda { pub parameters: Vec<(HirPattern, Type)>, pub return_type: Type, pub body: ExprId, + pub captures: Vec, } diff --git a/crates/noirc_frontend/src/hir_def/types.rs b/crates/noirc_frontend/src/hir_def/types.rs index 9441307bf28..e0feeb8746c 100644 --- a/crates/noirc_frontend/src/hir_def/types.rs +++ b/crates/noirc_frontend/src/hir_def/types.rs @@ -74,6 +74,11 @@ pub enum Type { /// A functions with arguments, and a return type. Function(Vec, Box), + /// A closure (a pair of a function pointer and a tuple of captured variables). + /// Stores the underlying function type, which has been modified such that the + /// first parameter is the type of the captured variables tuple. + Closure(Box), + /// &mut T MutableReference(Box), @@ -591,6 +596,7 @@ impl Type { parameters.iter().any(|parameter| parameter.contains_numeric_typevar(target_id)) || return_type.contains_numeric_typevar(target_id) } + Type::Closure(func) => func.contains_numeric_typevar(target_id), Type::Struct(struct_type, generics) => { generics.iter().enumerate().any(|(i, generic)| { if named_generic_id_matches_target(generic) { @@ -668,6 +674,9 @@ impl std::fmt::Display for Type { let args = vecmap(args, ToString::to_string); write!(f, "fn({}) -> {}", args.join(", "), ret) } + Type::Closure(func) => { + write!(f, "closure {}", func) // i.e. we produce a string such as "closure fn(args) -> ret" + } Type::MutableReference(element) => { write!(f, "&mut {element}") } @@ -1226,6 +1235,7 @@ impl Type { Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) 
=> unreachable!(), Type::Function(_, _) => unreachable!(), + Type::Closure(_) => unreachable!(), Type::Slice(_) => unreachable!("slices cannot be used in the abi"), Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"), } @@ -1345,6 +1355,10 @@ impl Type { let ret = Box::new(ret.substitute(type_bindings)); Type::Function(args, ret) } + Type::Closure(func) => { + let func = Box::new(func.substitute(type_bindings)); + Type::Closure(func) + } Type::MutableReference(element) => { Type::MutableReference(Box::new(element.substitute(type_bindings))) } @@ -1378,6 +1392,7 @@ impl Type { Type::Function(args, ret) => { args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) } + Type::Closure(func) => func.occurs(target_id), Type::MutableReference(element) => element.occurs(target_id), Type::FieldElement(_) @@ -1421,6 +1436,8 @@ impl Type { let ret = Box::new(ret.follow_bindings()); Function(args, ret) } + Closure(func) => Closure(Box::new(func.follow_bindings())), + MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), // Expect that this function should only be called on instantiated types diff --git a/crates/noirc_frontend/src/monomorphization/ast.rs b/crates/noirc_frontend/src/monomorphization/ast.rs index 7cac2ed8e4f..09e8e59f8e9 100644 --- a/crates/noirc_frontend/src/monomorphization/ast.rs +++ b/crates/noirc_frontend/src/monomorphization/ast.rs @@ -29,7 +29,6 @@ pub enum Expression { Tuple(Vec), ExtractTupleField(Box, usize), Call(Call), - Let(Let), Constrain(Box, Location), Assign(Assign), @@ -102,6 +101,13 @@ pub struct Binary { pub location: Location, } +#[derive(Debug, Clone)] +pub struct Lambda { + pub function: Ident, + pub env: Ident, + pub typ: Type, // TODO: Perhaps this is not necessary +} + #[derive(Debug, Clone)] pub struct If { pub condition: Box, @@ -223,6 +229,52 @@ impl Type { } } +pub fn type_of_lvalue(lvalue: &LValue) -> Type { + match lvalue { + LValue::Ident(ident) => 
ident.typ.clone(), + LValue::Index { element_type, .. } => element_type.clone(), + LValue::MemberAccess { object, field_index } => { + let tuple_type = type_of_lvalue(object.as_ref()); + match tuple_type { + Type::Tuple(fields) => fields[*field_index].clone(), + _ => unreachable!("ICE: Member access on non-tuple type"), + } + } + LValue::Dereference { element_type, .. } => element_type.clone(), + } +} + +pub fn type_of(expr: &Expression) -> Type { + match expr { + Expression::Ident(ident) => ident.typ.clone(), + Expression::Literal(lit) => match lit { + Literal::Integer(_, typ) => typ.clone(), + Literal::Bool(_) => Type::Bool, + Literal::Str(str) => Type::String(str.len() as u64), + Literal::Array(array) => { + Type::Array(array.contents.len() as u64, Box::new(array.element_type.clone())) + } + }, + Expression::Block(stmts) => type_of(stmts.last().unwrap()), + Expression::Unary(unary) => unary.result_type.clone(), + Expression::Binary(_binary) => unreachable!("TODO: How do we get the type of a Binary op"), + Expression::Index(index) => index.element_type.clone(), + Expression::Cast(cast) => cast.r#type.clone(), + Expression::For(_for_expr) => unreachable!("TODO: How do we get the type of a for loop?"), + Expression::If(if_expr) => if_expr.typ.clone(), + Expression::Tuple(elements) => Type::Tuple(elements.iter().map(type_of).collect()), + Expression::ExtractTupleField(tuple, index) => match tuple.as_ref() { + Expression::Tuple(fields) => type_of(&fields[*index]), + _ => unreachable!("ICE: Tuple field access on non-tuple type"), + }, + Expression::Call(call) => call.return_type.clone(), + Expression::Let(let_stmt) => type_of(let_stmt.expression.as_ref()), + Expression::Constrain(contraint, _) => type_of(contraint.as_ref()), + Expression::Assign(assign) => type_of_lvalue(&assign.lvalue), + Expression::Semi(expr) => type_of(expr.as_ref()), // TODO: Is this correct? 
+ } +} + #[derive(Debug, Clone)] pub struct Program { pub functions: Vec, diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index d8dfb1550c0..d23c29b4644 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -30,6 +30,11 @@ use self::ast::{Definition, FuncId, Function, LocalId, Program}; pub mod ast; pub mod printer; +struct LambdaContext { + env_ident: Box, + captures: Vec, +} + /// The context struct for the monomorphization pass. /// /// This struct holds the FIFO queue of functions to monomorphize, which is added to @@ -58,6 +63,8 @@ struct Monomorphizer<'interner> { /// Used to reference existing definitions in the HIR interner: &'interner NodeInterner, + lambda_envs_stack: Vec, + next_local_id: u32, next_function_id: u32, } @@ -103,6 +110,7 @@ impl<'interner> Monomorphizer<'interner> { next_local_id: 0, next_function_id: 0, interner, + lambda_envs_stack: Vec::new(), } } @@ -604,6 +612,15 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Block(definitions) } + /// Find a captured variable in the innermost closure + fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option { + let ctx = self.lambda_envs_stack.last()?; + ctx.captures + .iter() + .position(|capture| capture.ident.id == id) + .map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index)) + } + /// A local (ie non-global) ident only fn local_ident(&mut self, ident: &HirIdent) -> Option { let definition = self.interner.definition(ident.id); @@ -631,10 +648,10 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Ident(ident) } DefinitionKind::Global(expr_id) => self.expr(*expr_id), - DefinitionKind::Local(_) => { + DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { let ident = self.local_ident(&ident).unwrap(); ast::Expression::Ident(ident) - } + }), 
DefinitionKind::GenericType(type_variable) => { let value = match &*type_variable.borrow() { TypeBinding::Unbound(_) => { @@ -707,6 +724,19 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Function(args, ret) } + HirType::Closure(func) => { + match func.as_ref() { + HirType::Function(arguments, return_type) => { + let converted_args = vecmap(arguments, Self::convert_type); + let converted_ret = Box::new(Self::convert_type(&return_type)); + let fn_type = ast::Type::Function(converted_args, converted_ret); + let env_type = ast::Type::Tuple(vec![]); // TODO compute this + ast::Type::Tuple(vec![env_type, fn_type]) + } + _ => unreachable!("Unexpected closure type {}", func), + } + } + HirType::MutableReference(element) => { let element = Self::convert_type(element); ast::Type::MutableReference(Box::new(element)) @@ -739,17 +769,34 @@ impl<'interner> Monomorphizer<'interner> { } } + fn is_function_closure(&self, func: &ast::Expression) -> bool { + matches!(ast::type_of(func), ast::Type::Tuple(_)) + } + fn function_call( &mut self, call: HirCallExpression, id: node_interner::ExprId, ) -> ast::Expression { - let func = Box::new(self.expr(call.func)); - let arguments = vecmap(&call.arguments, |id| self.expr(*id)); + let original_func = Box::new(self.expr(call.func)); + let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let return_type = self.interner.id_type(id); let return_type = Self::convert_type(&return_type); let location = call.location; + let is_closure = self.is_function_closure(&*original_func); + + let func = if is_closure { + Box::new(ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 1usize)) + } else { + original_func.clone() + }; + + if is_closure { + let env_argument = + ast::Expression::ExtractTupleField(Box::new((*original_func).clone()), 0usize); + arguments.insert(0, env_argument); + } self.try_evaluate_call(&func, &call.arguments, &return_type) .unwrap_or(ast::Expression::Call(ast::Call { func, arguments, 
return_type, location })) } @@ -953,27 +1000,79 @@ impl<'interner> Monomorphizer<'interner> { Param(pattern, typ, noirc_abi::AbiVisibility::Private) })); - let parameters = self.parameters(parameters); - let body = self.expr(lambda.body); + let converted_parameters = self.parameters(parameters); let id = self.next_function_id(); - let return_type = ret_type.clone(); let name = lambda_name.to_owned(); - let unconstrained = false; + let return_type = ret_type.clone(); - let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; - self.push_function(id, function); + let env_local_id = self.next_local_id(); + let env_name = "env"; + let env_tuple = ast::Expression::Tuple(vecmap(&lambda.captures, |capture| { + match capture.transitive_capture_index { + Some(field_index) => match self.lambda_envs_stack.last() { + Some(lambda_ctx) => ast::Expression::ExtractTupleField( + lambda_ctx.env_ident.clone(), + field_index, + ), + None => unreachable!( + "Expected to find a parent closure environment, but found none" + ), + }, + None => { + let ident = self.local_ident(&capture.ident).unwrap(); + ast::Expression::Ident(ident) + } + } + })); + let env_typ = ast::type_of(&env_tuple); + + let env_let_stmt = ast::Expression::Let(ast::Let { + id: env_local_id, + mutable: true, + name: env_name.to_string(), + expression: Box::new(env_tuple), + }); - let typ = ast::Type::Function(parameter_types, Box::new(ret_type)); + let location = None; // TODO: This should match the location of the lambda expression + let mutable = false; + let definition = Definition::Local(env_local_id); - let name = lambda_name.to_owned(); - ast::Expression::Ident(ast::Ident { + let env_ident = ast::Expression::Ident(ast::Ident { + location, + mutable, + definition, + name: env_name.to_string(), + typ: env_typ.clone(), + }); + + // TODO: Is this costly? Can we avoid the copies somehow? 
+ self.lambda_envs_stack.push(LambdaContext { + env_ident: Box::new(env_ident.clone()), + captures: lambda.captures, + }); + let body = self.expr(lambda.body); + self.lambda_envs_stack.pop(); + + let lambda_fn_typ: ast::Type = ast::Type::Function(parameter_types, Box::new(ret_type)); + let lambda_fn = ast::Expression::Ident(ast::Ident { definition: Definition::Function(id), mutable: false, - location: None, - name, - typ, - }) + location: None, // TODO: This should match the location of the lambda expression + name: name.clone(), + typ: lambda_fn_typ, + }); + + let mut parameters = vec![]; + parameters.push((env_local_id, true, env_name.to_string(), env_typ)); + parameters.extend(converted_parameters); + + let unconstrained = false; + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); + ast::Expression::Block(vec![env_let_stmt, lambda_value]) } /// Implements std::unsafe::zeroed by returning an appropriate zeroed diff --git a/crates/noirc_frontend/src/node_interner.rs b/crates/noirc_frontend/src/node_interner.rs index f37daf45136..cac12be58e0 100644 --- a/crates/noirc_frontend/src/node_interner.rs +++ b/crates/noirc_frontend/src/node_interner.rs @@ -634,5 +634,6 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Constant(_) | Type::Error | Type::Struct(_, _) => None, + Type::Closure(_) => None, // TODO: Is this correct? How do we add methods to functions? Can we do the same for closures? } }