Skip to content

Commit

Permalink
feat: Initial work on rewriting closures to regular functions with hi…
Browse files Browse the repository at this point in the history
…dden env

This commit implements the following mechanism:

On a line where a lambda expression is encountered, we initialize a tuple for the captured lambda environment and we rewrite the lambda to a regular function taking this environment as an additional parameter. All calls to the closure are then modified to insert this hidden parameter.

In other words, the following code:

```
let x = some_value;
let closure = |a| x + a;

println(closure(10));
println(closure(20));
```

is rewritten to:

```
fn closure(env: (Field,), a: Field) -> Field {
  env.0 + a
}

let x = some_value;
let closure_env = (x,);

println(closure(closure_env, 10));
println(closure(closure_env, 20));
```

In the presence of nested closures, we propagate the captured variables implicitly through all intermediate closures:

```
let x = some_value;

let closure = |a, c|
  # here, `x` is initialized from the hidden env of the outer closure
  let inner_closure = |b|
    a + b + x

  inner_closure(c)
```

To make these transforms possible, the following changes were made to the logic of the HIR resolver and the monomorphization pass:

* In the HIR resolver pass, the code determines the precise list of variables captured by each lambda. Along with the list, we compute the index of each captured var within the parent closure's environment (when the capture is propagated).

* Introduction of a new `Closure` type in order to be able to recognize the call-sites that need the automatic environment variable treatment. It's a bit unfortunate that the Closure type is defined within the `AST` modules that are used to describe the output of the monomorphization pass, because we aim to eliminate all closures during the pass. A better solution would have been possible if the type check pass after HIR resolution was outputting types specific to the HIR pass (then the closures would exist only within this separate non-simplified type system).

* The majority of the work is in the Lambda processing step in the monomorphizer which performs the necessary transformations based on the above information.

Remaining things to do:

* There are a number of pending TODO items for various minor unresolved loose ends in the code.

* There are a lot of possible additional tests to be written.

* Update docs
  • Loading branch information
alehander92 committed Jul 18, 2023
1 parent 211e251 commit 56b7021
Show file tree
Hide file tree
Showing 10 changed files with 446 additions and 82 deletions.
1 change: 0 additions & 1 deletion crates/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ pub fn compile_no_check(
main_function: FuncId,
) -> Result<CompiledProgram, FileDiagnostic> {
let program = monomorphize(main_function, &context.def_interner);

let (circuit, debug, abi) = if options.experimental_ssa {
experimental_create_circuit(program, options.show_ssa, options.show_output)?
} else {
Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ impl SsaContext {
self.log(enable_logging, "reduce", "\ninlining:");
inline::inline_tree(self, self.first_block, &decision)?;

self.log(enable_logging, "reduce", "\nmerging paths:");
block::merge_path(self, self.first_block, BlockId::dummy(), None)?;
//The CFG is now fully flattened, so we keep only the first block.
let mut to_remove = Vec::new();
Expand Down
175 changes: 143 additions & 32 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
//
// XXX: Resolver does not check for unused functions
use crate::hir_def::expr::{
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression,
HirConstructorExpression, HirExpression, HirForExpression, HirIdent, HirIfExpression,
HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess,
HirMethodCallExpression, HirPrefixExpression,
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar,
HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral,
HirMemberAccess, HirMethodCallExpression, HirPrefixExpression,
};
use crate::token::Attribute;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -57,6 +57,11 @@ type Scope = GenericScope<String, ResolverMeta>;
type ScopeTree = GenericScopeTree<String, ResolverMeta>;
type ScopeForest = GenericScopeForest<String, ResolverMeta>;

pub struct LambdaContext {
captures: Vec<HirCapturedVar>,
scope_index: usize,
}

/// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1
/// definition in scope, and to convert the AST into the HIR.
///
Expand All @@ -80,12 +85,10 @@ pub struct Resolver<'a> {
/// were declared in.
generics: Vec<(Rc<String>, TypeVariable, Span)>,

/// Lambdas share the function scope of the function they're defined in,
/// so to identify whether they use any variables from the parent function
/// we keep track of the scope index a variable is declared in. When a lambda
/// is declared we push a scope and set this lambda_index to the scope index.
/// Any variable from a scope less than that must be from the parent function.
lambda_index: usize,
/// When resolving lambda expressions, we need to keep track of the variables
/// that are captured. We do this in order to create the hidden environment
/// parameter for the lambda function.
lambda_stack: Vec<LambdaContext>,
}

/// ResolverMetas are tagged onto each definition to track how many times they are used
Expand All @@ -111,7 +114,7 @@ impl<'a> Resolver<'a> {
self_type: None,
generics: Vec::new(),
errors: Vec::new(),
lambda_index: 0,
lambda_stack: Vec::new(),
file,
}
}
Expand All @@ -124,10 +127,6 @@ impl<'a> Resolver<'a> {
self.errors.push(err);
}

fn current_lambda_index(&self) -> usize {
self.scopes.current_scope_index()
}

/// Resolving a function involves interning the metadata
/// interning any statements inside of the function
/// and interning the function itself
Expand Down Expand Up @@ -276,25 +275,25 @@ impl<'a> Resolver<'a> {
//
// If a variable is not found, then an error is logged and a dummy id
// is returned, for better error reporting UX
fn find_variable_or_default(&mut self, name: &Ident) -> HirIdent {
fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) {
self.find_variable(name).unwrap_or_else(|error| {
self.push_err(error);
let id = DefinitionId::dummy_id();
let location = Location::new(name.span(), self.file);
HirIdent { location, id }
(HirIdent { location, id }, 0)
})
}

fn find_variable(&mut self, name: &Ident) -> Result<HirIdent, ResolverError> {
fn find_variable(&mut self, name: &Ident) -> Result<(HirIdent, usize), ResolverError> {
// Find the definition for this Ident
let scope_tree = self.scopes.current_scope_tree();
let variable = scope_tree.find(&name.0.contents);

let location = Location::new(name.span(), self.file);
if let Some((variable_found, _)) = variable {
if let Some((variable_found, scope)) = variable {
variable_found.num_times_used += 1;
let id = variable_found.ident.id;
Ok(HirIdent { location, id })
Ok((HirIdent { location, id }, scope))
} else {
Err(ResolverError::VariableNotDeclared {
name: name.0.contents.clone(),
Expand Down Expand Up @@ -478,24 +477,24 @@ impl<'a> Resolver<'a> {
}
}

fn get_ident_from_path(&mut self, path: Path) -> HirIdent {
fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) {
let location = Location::new(path.span(), self.file);

let error = match path.as_ident().map(|ident| self.find_variable(ident)) {
Some(Ok(ident)) => return ident,
Some(Ok(found)) => return found,
// Try to look it up as a global, but still issue the first error if we fail
Some(Err(error)) => match self.lookup_global(path) {
Ok(id) => return HirIdent { location, id },
Ok(id) => return (HirIdent { location, id }, 0),
Err(_) => error,
},
None => match self.lookup_global(path) {
Ok(id) => return HirIdent { location, id },
Ok(id) => return (HirIdent { location, id }, 0),
Err(error) => error,
},
};
self.push_err(error);
let id = DefinitionId::dummy_id();
HirIdent { location, id }
(HirIdent { location, id }, 0)
}

/// Translates an UnresolvedType to a Type
Expand Down Expand Up @@ -774,12 +773,15 @@ impl<'a> Resolver<'a> {
Self::find_numeric_generics_in_type(field, found);
}
}

Type::Function(parameters, return_type) => {
for parameter in parameters {
Self::find_numeric_generics_in_type(parameter, found);
}
Self::find_numeric_generics_in_type(return_type, found);
}
Type::Closure(func) => Self::find_numeric_generics_in_type(func, found),

Type::Struct(struct_type, generics) => {
for (i, generic) in generics.iter().enumerate() {
if let Type::NamedGeneric(type_variable, name) = generic {
Expand Down Expand Up @@ -841,7 +843,7 @@ impl<'a> Resolver<'a> {
fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue {
match lvalue {
LValue::Ident(ident) => {
HirLValue::Ident(self.find_variable_or_default(&ident), Type::Error)
HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error)
}
LValue::MemberAccess { object, field_name } => {
let object = Box::new(self.resolve_lvalue(*object));
Expand Down Expand Up @@ -889,7 +891,52 @@ impl<'a> Resolver<'a> {
// Otherwise, then it is referring to an Identifier
// This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10;
// If the expression is a singular indent, we search the resolver's current scope as normal.
let hir_ident = self.get_ident_from_path(path);
let (hir_ident, var_scope_index) = self.get_ident_from_path(path);

if hir_ident.id != DefinitionId::dummy_id() {
match self.interner.definition(hir_ident.id).kind {
DefinitionKind::Function(_) => {}
DefinitionKind::Global(_) => {}
DefinitionKind::GenericType(_) => {}
// We ignore the above definition kinds because only local variables can be captured by closures.
DefinitionKind::Local(_) => {
let mut transitive_capture_index: Option<usize> = None;

for lambda_index in 0..self.lambda_stack.len() {
if self.lambda_stack[lambda_index].scope_index > var_scope_index {
// Beware: the same variable may be captured multiple times, so we check
// for its presence before adding the capture below.
let pos = self.lambda_stack[lambda_index]
.captures
.iter()
.position(|capture| capture.ident.id == hir_ident.id);

if pos.is_none() {
self.lambda_stack[lambda_index].captures.push(
HirCapturedVar {
ident: hir_ident,
transitive_capture_index,
},
);
}

if lambda_index + 1 < self.lambda_stack.len() {
// There is more than one closure between the current scope and
// the scope of the variable, so this is a propagated capture.
// We need to track the transitive capture index as we go up in
// the closure stack.
transitive_capture_index = Some(pos.unwrap_or(
// If this was a fresh capture, we added it to the end of
// the captures vector:
self.lambda_stack[lambda_index].captures.len() - 1,
))
}
}
}
}
}
}

HirExpression::Ident(hir_ident)
}
ExpressionKind::Prefix(prefix) => {
Expand Down Expand Up @@ -1010,8 +1057,10 @@ impl<'a> Resolver<'a> {
// We must stay in the same function scope as the parent function to allow for closures
// to capture variables. This is currently limited to immutable variables.
ExpressionKind::Lambda(lambda) => self.in_new_scope(|this| {
let new_index = this.current_lambda_index();
let old_index = std::mem::replace(&mut this.lambda_index, new_index);
let scope_index = this.scopes.current_scope_index();

this.lambda_stack
.push(LambdaContext { captures: Vec::new(), scope_index: scope_index });

let parameters = vecmap(lambda.parameters, |(pattern, typ)| {
let parameter = DefinitionKind::Local(None);
Expand All @@ -1021,8 +1070,14 @@ impl<'a> Resolver<'a> {
let return_type = this.resolve_inferred_type(lambda.return_type);
let body = this.resolve_expression(lambda.body);

this.lambda_index = old_index;
HirExpression::Lambda(HirLambda { parameters, return_type, body })
let lambda_context = this.lambda_stack.pop().unwrap();

HirExpression::Lambda(HirLambda {
parameters,
return_type,
body,
captures: lambda_context.captures,
})
}),
};

Expand Down Expand Up @@ -1299,6 +1354,7 @@ pub fn verify_mutable_reference(interner: &NodeInterner, rhs: ExprId) -> Result<
#[cfg(test)]
mod test {

use core::panic;
use std::collections::HashMap;

use fm::FileId;
Expand All @@ -1322,7 +1378,9 @@ mod test {
// and functions can be forward declared
fn resolve_src_code(src: &str, func_namespace: Vec<&str>) -> Vec<ResolverError> {
let (program, errors) = parse_program(src);
assert!(errors.is_empty());
if !errors.is_empty() {
panic!("Unexpected parse errors in test code: {:?}", errors);
}

let mut interner = NodeInterner::default();

Expand Down Expand Up @@ -1530,6 +1588,59 @@ mod test {
assert!(errors.is_empty());
}

#[test]
fn resolve_basic_closure() {
let src = r#"
fn main(x : Field) -> pub Field {
let closure = |y| y + x;
closure(x)
}
"#;

let errors = resolve_src_code(src, vec!["main", "foo"]);
if !errors.is_empty() {
println!("Unexpected errors: {:?}", errors);
assert!(false); // there should be no errors
}
}

#[test]
fn resolve_complex_closures() {
let src = r#"
fn main(x: Field) -> pub Field {
let closure_without_captures = |x| x + x;
let a = closure_without_captures(1);
let closure_capturing_a_param = |y| y + x;
let b = closure_capturing_a_param(2);
let closure_capturing_a_local_var = |y| y + b;
let c = closure_capturing_a_local_var(3);
let closure_with_transitive_captures = |y| {
let d = 5;
let nested_closure = |z| {
let doubly_nested_closure = |w| w + x + b;
a + z + y + d + x + doubly_nested_closure(4) + x + y
};
let res = nested_closure(5);
res
};
a + b + c + closure_with_transitive_captures(6)
}
"#;

let errors = resolve_src_code(src, vec!["main", "foo"]);
if !errors.is_empty() {
println!("Unexpected errors: {:?}", errors);
assert!(false); // there should be no errors
}

// TODO: Create a more sophisticated set of search functions over the HIR, so we can check
// that the correct variables are captured in each closure
}

fn path_unresolved_error(err: ResolverError, expected_unresolved_path: &str) {
match err {
ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => {
Expand Down
Loading

0 comments on commit 56b7021

Please sign in to comment.