Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa refactoring): Support basic recursive functions #1387

Merged
merged 2 commits into from
May 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 104 additions & 4 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,25 @@ impl<'function> PerFunctionContext<'function> {
else_destination,
}) => {
let condition = self.translate_value(*condition);
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context.builder.terminate_with_jmpif(condition, then_block, else_block);

// See if the value of the condition is known, and if so only inline the reachable
// branch. This lets us inline some recursive functions without recurring forever.
let dfg = &mut self.context.builder.current_function.dfg;
match dfg.get_numeric_constant(condition) {
Some(constant) => {
let next_block =
if constant.is_zero() { *else_destination } else { *then_destination };
let next_block = self.translate_block(next_block, block_queue);
self.context.builder.terminate_with_jmp(next_block, vec![]);
}
None => {
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context
.builder
.terminate_with_jmpif(condition, then_block, else_block);
}
}
None
}
Some(TerminatorInstruction::Return { return_values }) => {
Expand All @@ -429,7 +445,11 @@ impl<'function> PerFunctionContext<'function> {
#[cfg(test)]
mod test {
use crate::ssa_refactor::{
ir::{instruction::BinaryOp, map::Id, types::Type},
ir::{
instruction::{BinaryOp, TerminatorInstruction},
map::Id,
types::Type,
},
ssa_builder::FunctionBuilder,
};

Expand Down Expand Up @@ -529,4 +549,84 @@ mod test {
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}

#[test]
fn recursive_functions() {
// fn main f0 {
// b0():
// v0 = call factorial(Field 5)
// return v0
// }
// fn factorial f1 {
// b0(v0: Field):
// v1 = lt v0, Field 1
// jmpif v1, then: b1, else: b2
// b1():
// return Field 1
// b2():
// v2 = sub v0, Field 1
// v3 = call factorial(v2)
// v4 = mul v0, v3
// return v4
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);

let factorial_id = Id::test_new(1);
let factorial = builder.import_function(factorial_id);

let five = builder.field_constant(5u128);
let results = builder.insert_call(factorial, vec![five], vec![Type::field()]).to_vec();
builder.terminate_with_return(results);

builder.new_function("factorial".into(), factorial_id);
let b1 = builder.insert_block();
let b2 = builder.insert_block();

let one = builder.field_constant(1u128);

let v0 = builder.add_parameter(Type::field());
let v1 = builder.insert_binary(v0, BinaryOp::Lt, one);
builder.terminate_with_jmpif(v1, b1, b2);

builder.switch_to_block(b1);
builder.terminate_with_return(vec![one]);

builder.switch_to_block(b2);
let factorial_id = builder.import_function(factorial_id);
let v2 = builder.insert_binary(v0, BinaryOp::Sub, one);
let v3 = builder.insert_call(factorial_id, vec![v2], vec![Type::field()])[0];
let v4 = builder.insert_binary(v0, BinaryOp::Mul, v3);
builder.terminate_with_return(vec![v4]);

let ssa = builder.finish();
assert_eq!(ssa.functions.len(), 2);

// Expected SSA:
//
// fn main f2 {
// b0():
// jmp b1()
// b1():
// return Field 120
// }
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);

let main = inlined.main();
let b1 = &main.dfg[b1];

match b1.terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
assert_eq!(return_values.len(), 1);
let value = main
.dfg
.get_numeric_constant(return_values[0])
.expect("Expected a constant for the return value")
.to_u128();
assert_eq!(value, 120);
}
other => unreachable!("Unexpected terminator {other:?}"),
}
}
}