Skip to content

Commit

Permalink
chore(ssa refactoring): Support basic recursive functions (#1387)
Browse files Browse the repository at this point in the history
* Support basic recursive functions

* Add comment
  • Loading branch information
jfecher authored May 24, 2023
1 parent b938c7e commit cceaca0
Showing 1 changed file with 104 additions and 4 deletions.
108 changes: 104 additions & 4 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,25 @@ impl<'function> PerFunctionContext<'function> {
}
TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => {
let condition = self.translate_value(*condition);
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context.builder.terminate_with_jmpif(condition, then_block, else_block);

// See if the value of the condition is known, and if so only inline the reachable
// branch. This lets us inline some recursive functions without recurring forever.
let dfg = &mut self.context.builder.current_function.dfg;
match dfg.get_numeric_constant(condition) {
Some(constant) => {
let next_block =
if constant.is_zero() { *else_destination } else { *then_destination };
let next_block = self.translate_block(next_block, block_queue);
self.context.builder.terminate_with_jmp(next_block, vec![]);
}
None => {
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context
.builder
.terminate_with_jmpif(condition, then_block, else_block);
}
}
None
}
TerminatorInstruction::Return { return_values } => {
Expand All @@ -424,7 +440,11 @@ impl<'function> PerFunctionContext<'function> {
#[cfg(test)]
mod test {
use crate::ssa_refactor::{
ir::{instruction::BinaryOp, map::Id, types::Type},
ir::{
instruction::{BinaryOp, TerminatorInstruction},
map::Id,
types::Type,
},
ssa_builder::FunctionBuilder,
};

Expand Down Expand Up @@ -524,4 +544,84 @@ mod test {
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}

#[test]
fn recursive_functions() {
// fn main f0 {
// b0():
// v0 = call factorial(Field 5)
// return v0
// }
// fn factorial f1 {
// b0(v0: Field):
// v1 = lt v0, Field 1
// jmpif v1, then: b1, else: b2
// b1():
// return Field 1
// b2():
// v2 = sub v0, Field 1
// v3 = call factorial(v2)
// v4 = mul v0, v3
// return v4
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);

let factorial_id = Id::test_new(1);
let factorial = builder.import_function(factorial_id);

let five = builder.field_constant(5u128);
let results = builder.insert_call(factorial, vec![five], vec![Type::field()]).to_vec();
builder.terminate_with_return(results);

builder.new_function("factorial".into(), factorial_id);
let b1 = builder.insert_block();
let b2 = builder.insert_block();

let one = builder.field_constant(1u128);

let v0 = builder.add_parameter(Type::field());
let v1 = builder.insert_binary(v0, BinaryOp::Lt, one);
builder.terminate_with_jmpif(v1, b1, b2);

builder.switch_to_block(b1);
builder.terminate_with_return(vec![one]);

builder.switch_to_block(b2);
let factorial_id = builder.import_function(factorial_id);
let v2 = builder.insert_binary(v0, BinaryOp::Sub, one);
let v3 = builder.insert_call(factorial_id, vec![v2], vec![Type::field()])[0];
let v4 = builder.insert_binary(v0, BinaryOp::Mul, v3);
builder.terminate_with_return(vec![v4]);

let ssa = builder.finish();
assert_eq!(ssa.functions.len(), 2);

// Expected SSA:
//
// fn main f2 {
// b0():
// jmp b1()
// b1():
// return Field 120
// }
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);

let main = inlined.main();
let b1 = &main.dfg[b1];

match b1.terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
assert_eq!(return_values.len(), 1);
let value = main
.dfg
.get_numeric_constant(return_values[0])
.expect("Expected a constant for the return value")
.to_u128();
assert_eq!(value, 120);
}
other => unreachable!("Unexpected terminator {other:?}"),
}
}
}

0 comments on commit cceaca0

Please sign in to comment.