Skip to content

Commit

Permalink
feat: allow bitshifts to be represented in SSA for brillig (#4301)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4276 

## Summary\*

Quick and dirty implementation of replacing bitshifts in SSA once
function inlining has occurred. Not a fan of how we are juggling the
instructions but will need to think about how to do this cleanly.

cc @sirasistant 

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench authored Feb 8, 2024
1 parent 1786d8a commit d86ff1a
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,8 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
BinaryOp::And => BinaryIntOp::And,
BinaryOp::Or => BinaryIntOp::Or,
BinaryOp::Xor => BinaryIntOp::Xor,
BinaryOp::Shl => BinaryIntOp::Shl,
BinaryOp::Shr => BinaryIntOp::Shr,
};

BrilligBinaryOp::Integer { op: operation, bit_size }
Expand Down
5 changes: 2 additions & 3 deletions compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ impl DebugToString for BinaryIntOp {
BinaryIntOp::And => "&&".into(),
BinaryIntOp::Or => "||".into(),
BinaryIntOp::Xor => "^".into(),
BinaryIntOp::Shl | BinaryIntOp::Shr => {
unreachable!("bit shift should have been replaced")
}
BinaryIntOp::Shl => "<<".into(),
BinaryIntOp::Shr => ">>".into(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub(crate) fn optimize_into_acir(
// and this pass is missed, slice merging will fail inside of flattening.
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::fold_constants, "After Constant Folding:")
Expand Down
3 changes: 3 additions & 0 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,9 @@ impl Context {
bit_count,
self.current_side_effects_enabled_var,
),
BinaryOp::Shl | BinaryOp::Shr => unreachable!(
"ICE - bit shift operators do not exist in ACIR and should have been replaced"
),
}
}

Expand Down
112 changes: 1 addition & 111 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ use super::{
basic_block::BasicBlock,
dfg::{CallStack, InsertInstructionResult},
function::RuntimeType,
instruction::{ConstrainError, Endian, InstructionId, Intrinsic},
types::NumericType,
instruction::{ConstrainError, InstructionId, Intrinsic},
},
ssa_gen::Ssa,
};
Expand Down Expand Up @@ -279,115 +278,6 @@ impl FunctionBuilder {
self.insert_instruction(Instruction::Call { func, arguments }, Some(result_types)).results()
}

/// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs
/// and truncate the result to bit_size
pub(crate) fn insert_wrapping_shift_left(
&mut self,
lhs: ValueId,
rhs: ValueId,
bit_size: u32,
) -> ValueId {
let base = self.field_constant(FieldElement::from(2_u128));
let typ = self.current_function.dfg.type_of_value(lhs);
let (max_bit, pow) =
if let Some(rhs_constant) = self.current_function.dfg.get_numeric_constant(rhs) {
// Happy case is that we know precisely by how many bits the the integer will
// increase: lhs_bit_size + rhs
let bit_shift_size = rhs_constant.to_u128() as u32;

let (rhs_bit_size_pow_2, overflows) = 2_u128.overflowing_pow(bit_shift_size);
if overflows {
assert!(bit_size < 128, "ICE - shift left with big integers are not supported");
if bit_size < 128 {
let zero = self.numeric_constant(FieldElement::zero(), typ);
return InsertInstructionResult::SimplifiedTo(zero).first();
}
}
let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ);

let max_lhs_bits = self.current_function.dfg.get_value_max_num_bits(lhs);

(max_lhs_bits + bit_shift_size, pow)
} else {
// we use a predicate to nullify the result in case of overflow
let bit_size_var =
self.numeric_constant(FieldElement::from(bit_size as u128), typ.clone());
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ.clone());
// we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value
let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size));
let pow = self.pow(base, rhs_unsigned);
let pow = self.insert_cast(pow, typ);
(FieldElement::max_num_bits(), self.insert_binary(predicate, BinaryOp::Mul, pow))
};

if max_bit <= bit_size {
self.insert_binary(lhs, BinaryOp::Mul, pow)
} else {
let result = self.insert_binary(lhs, BinaryOp::Mul, pow);
self.insert_truncate(result, bit_size, max_bit)
}
}

/// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs
pub(crate) fn insert_shift_right(
&mut self,
lhs: ValueId,
rhs: ValueId,
bit_size: u32,
) -> ValueId {
let lhs_typ = self.type_of_value(lhs);
let base = self.field_constant(FieldElement::from(2_u128));
// we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value
let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size));
let pow = self.pow(base, rhs_unsigned);
// We need at least one more bit for the case where rhs == bit_size
let div_type = Type::unsigned(bit_size + 1);
let casted_lhs = self.insert_cast(lhs, div_type.clone());
let casted_pow = self.insert_cast(pow, div_type);
let div_result = self.insert_binary(casted_lhs, BinaryOp::Div, casted_pow);
// We have to cast back to the original type
self.insert_cast(div_result, lhs_typ)
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
/// Pseudo-code of the computation:
/// let mut r = 1;
/// let rhs_bits = to_bits(rhs);
/// for i in 1 .. bit_size + 1 {
/// let r_squared = r * r;
/// let b = rhs_bits[bit_size - i];
/// r = (r_squared * lhs * b) + (1 - b) * r_squared;
/// }
pub(crate) fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.current_function.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.import_intrinsic_id(Intrinsic::ToBits(Endian::Little));
let length = self.field_constant(FieldElement::from(bit_size as i128));
let result_types =
vec![Type::field(), Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)];
let rhs_bits = self.insert_call(to_bits, vec![rhs, length], result_types);
let rhs_bits = rhs_bits[1];
let one = self.field_constant(FieldElement::one());
let mut r = one;
for i in 1..bit_size + 1 {
let r_squared = self.insert_binary(r, BinaryOp::Mul, r);
let a = self.insert_binary(r_squared, BinaryOp::Mul, lhs);
let idx = self.field_constant(FieldElement::from((bit_size - i) as i128));
let b = self.insert_array_get(rhs_bits, idx, Type::bool());
let not_b = self.insert_not(b);
let b = self.insert_cast(b, Type::field());
let not_b = self.insert_cast(not_b, Type::field());
let r1 = self.insert_binary(a, BinaryOp::Mul, b);
let r2 = self.insert_binary(r_squared, BinaryOp::Mul, not_b);
r = self.insert_binary(r1, BinaryOp::Add, r2);
}
r
} else {
unreachable!("Value must be unsigned in power operation");
}
}

/// Insert an instruction to extract an element from an array
pub(crate) fn insert_array_get(
&mut self,
Expand Down
34 changes: 33 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ pub(crate) enum BinaryOp {
Or,
/// Bitwise xor (^)
Xor,
/// Bitshift left (<<)
Shl,
/// Bitshift right (>>)
Shr,
}

impl std::fmt::Display for BinaryOp {
Expand All @@ -53,6 +57,8 @@ impl std::fmt::Display for BinaryOp {
BinaryOp::And => write!(f, "and"),
BinaryOp::Or => write!(f, "or"),
BinaryOp::Xor => write!(f, "xor"),
BinaryOp::Shl => write!(f, "shl"),
BinaryOp::Shr => write!(f, "shr"),
}
}
}
Expand Down Expand Up @@ -215,7 +221,27 @@ impl Binary {
return SimplifyResult::SimplifiedTo(zero);
}
}
}
BinaryOp::Shl => return SimplifyResult::None,
BinaryOp::Shr => {
// Bit shifts by constants can be treated as divisions.
if let Some(rhs_const) = rhs {
if rhs_const >= FieldElement::from(operand_type.bit_size() as u128) {
// Shifting by the full width of the operand type, any `lhs` goes to zero.
let zero = dfg.make_constant(FieldElement::zero(), operand_type);
return SimplifyResult::SimplifiedTo(zero);
}

// `two_pow_rhs` is limited to be at most `2 ^ {operand_bitsize - 1}` so it fits in `operand_type`.
let two_pow_rhs = FieldElement::from(2u128).pow(&rhs_const);
let two_pow_rhs = dfg.make_constant(two_pow_rhs, operand_type);
return SimplifyResult::SimplifiedToInstruction(Instruction::binary(
BinaryOp::Div,
self.lhs,
two_pow_rhs,
));
}
}
};
SimplifyResult::None
}
}
Expand Down Expand Up @@ -314,6 +340,8 @@ impl BinaryOp {
BinaryOp::And => None,
BinaryOp::Or => None,
BinaryOp::Xor => None,
BinaryOp::Shl => None,
BinaryOp::Shr => None,
}
}

Expand All @@ -329,6 +357,8 @@ impl BinaryOp {
BinaryOp::Xor => |x, y| Some(x ^ y),
BinaryOp::Eq => |x, y| Some((x == y) as u128),
BinaryOp::Lt => |x, y| Some((x < y) as u128),
BinaryOp::Shl => |x, y| Some(x << y),
BinaryOp::Shr => |x, y| Some(x >> y),
}
}

Expand All @@ -344,6 +374,8 @@ impl BinaryOp {
BinaryOp::Xor => |x, y| Some(x ^ y),
BinaryOp::Eq => |x, y| Some((x == y) as i128),
BinaryOp::Lt => |x, y| Some((x < y) as i128),
BinaryOp::Shl => |x, y| Some(x << y),
BinaryOp::Shr => |x, y| Some(x >> y),
}
}
}
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ mod die;
pub(crate) mod flatten_cfg;
mod inlining;
mod mem2reg;
mod remove_bit_shifts;
mod simplify_cfg;
mod unrolling;
Loading

0 comments on commit d86ff1a

Please sign in to comment.