Skip to content

Commit

Permalink
feat: simplify jmpifs by reversing branches if condition is negated (
Browse files Browse the repository at this point in the history
…#5891)

Co-authored-by: Maxim Vezenov <mvezenov@gmail.com>
  • Loading branch information
TomAFrench and vezenovm authored Nov 26, 2024
1 parent 30f8378 commit ba7a568
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 1 deletion.
112 changes: 111 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use crate::ssa::{
basic_block::BasicBlockId,
cfg::ControlFlowGraph,
function::{Function, RuntimeType},
instruction::TerminatorInstruction,
instruction::{Instruction, TerminatorInstruction},
value::Value,
},
ssa_gen::Ssa,
};
Expand All @@ -31,6 +32,7 @@ impl Ssa {
/// 4. Removing any blocks which have no instructions other than a single terminating jmp.
/// 5. Replacing any jmpifs with constant conditions with jmps. If this causes the block to have
/// only 1 successor then (2) also will be applied.
/// 6. Replacing any jmpifs with a negated condition with a jmpif with a un-negated condition and reversed branches.
///
/// Currently, 1 is unimplemented.
#[tracing::instrument(level = "trace", skip(self))]
Expand All @@ -55,6 +57,8 @@ impl Function {
stack.extend(self.dfg[block].successors().filter(|block| !visited.contains(block)));
}

check_for_negated_jmpif_condition(self, block, &mut cfg);

// This call is before try_inline_into_predecessor so that if it succeeds in changing a
// jmpif into a jmp, the block may then be inlined entirely into its predecessor in try_inline_into_predecessor.
check_for_constant_jmpif(self, block, &mut cfg);
Expand Down Expand Up @@ -184,6 +188,55 @@ fn check_for_double_jmp(function: &mut Function, block: BasicBlockId, cfg: &mut
cfg.recompute_block(function, block);
}

/// Optimize a jmpif on a negated condition by swapping the branches.
fn check_for_negated_jmpif_condition(
function: &mut Function,
block: BasicBlockId,
cfg: &mut ControlFlowGraph,
) {
if matches!(function.runtime(), RuntimeType::Acir(_)) {
// Swapping the `then` and `else` branches of a `JmpIf` within an ACIR function
// can result in the situation where the branches merge together again in the `then` block, e.g.
//
// acir(inline) fn main f0 {
// b0(v0: u1):
// jmpif v0 then: b2, else: b1
// b2():
// return
// b1():
// jmp b2()
// }
//
// This breaks the `flatten_cfg` pass as it assumes that merges only happen in
// the `else` block or a 3rd block.
//
// See: https://github.com/noir-lang/noir/pull/5891#issuecomment-2500219428
return;
}

if let Some(TerminatorInstruction::JmpIf {
condition,
then_destination,
else_destination,
call_stack,
}) = function.dfg[block].terminator()
{
if let Value::Instruction { instruction, .. } = function.dfg[*condition] {
if let Instruction::Not(negated_condition) = function.dfg[instruction] {
let call_stack = call_stack.clone();
let jmpif = TerminatorInstruction::JmpIf {
condition: negated_condition,
then_destination: *else_destination,
else_destination: *then_destination,
call_stack,
};
function.dfg[block].set_terminator(jmpif);
cfg.recompute_block(function, block);
}
}
}
}

/// If the given block has block parameters, replace them with the jump arguments from the predecessor.
///
/// Currently, if this function is needed, `try_inline_into_predecessor` will also always apply,
Expand Down Expand Up @@ -246,6 +299,8 @@ mod test {
map::Id,
types::Type,
},
opt::assert_normalized_ssa_equals,
Ssa,
};
use acvm::acir::AcirField;

Expand Down Expand Up @@ -359,4 +414,59 @@ mod test {
other => panic!("Unexpected terminator {other:?}"),
}
}

#[test]
fn swap_negated_jmpif_branches_in_brillig() {
let src = "
brillig(inline) fn main f0 {
b0(v0: u1):
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = not v0
jmpif v3 then: b1, else: b2
b1():
store Field 2 at v1
jmp b2()
b2():
v5 = load v1 -> Field
v6 = eq v5, Field 2
constrain v5 == Field 2
return
}";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
brillig(inline) fn main f0 {
b0(v0: u1):
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = not v0
jmpif v0 then: b2, else: b1
b2():
v5 = load v1 -> Field
v6 = eq v5, Field 2
constrain v5 == Field 2
return
b1():
store Field 2 at v1
jmp b2()
}";
assert_normalized_ssa_equals(ssa.simplify_cfg(), expected);
}

#[test]
fn does_not_swap_negated_jmpif_branches_in_acir() {
let src = "
acir(inline) fn main f0 {
b0(v0: u1):
v1 = not v0
jmpif v1 then: b1, else: b2
b1():
jmp b2()
b2():
return
}";
let ssa = Ssa::from_str(src).unwrap();
assert_normalized_ssa_equals(ssa.simplify_cfg(), src);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "negated_jmpif_condition"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "2"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main(mut x: Field) {
let mut q = 0;

if x != 10 {
q = 2;
}

assert(q == 2);
}

0 comments on commit ba7a568

Please sign in to comment.