Skip to content

Commit

Permalink
feat: Add support for bitshifts by distances known at runtime (#2072)
Browse files Browse the repository at this point in the history
* remove shr and shl from ssa instruction

* move bit_shift_runtime test to test_data

* code review, fix typo

* Forbid signed integers for bit shift and fix brillig failing test

* Check for signeness also during the delayed checks

* Add missing method

* Code review

* Code review
  • Loading branch information
guipublic authored Aug 2, 2023
1 parent 50b2816 commit b0fbc53
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "bit_shifts_runtime"
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = 64
y = 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main(x: u64, y: u64) {
// runtime shifts on comptime values
assert(64 << y == 128);
assert(64 >> y == 32);

// runtime shifts on runtime values
assert(x << y == 128);
assert(x >> y == 32);
}
24 changes: 11 additions & 13 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);

let heap_vec = self.brillig_context.extract_heap_vector(target_slice);
self.brillig_context.radix_instruction(
source,
self.function_context.extract_heap_vector(target_slice),
heap_vec,
radix,
limb_count,
matches!(endianness, Endian::Big),
Expand All @@ -355,10 +355,10 @@ impl<'block> BrilligBlock<'block> {
);

let radix = self.brillig_context.make_constant(2_usize.into());

let heap_vec = self.brillig_context.extract_heap_vector(target_slice);
self.brillig_context.radix_instruction(
source,
self.function_context.extract_heap_vector(target_slice),
heap_vec,
radix,
limb_count,
matches!(endianness, Endian::Big),
Expand Down Expand Up @@ -589,7 +589,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
let item_value = self.convert_ssa_register_value(arguments[1], dfg);
slice_push_back_operation(
self.brillig_context,
Expand All @@ -604,7 +604,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
let item_value = self.convert_ssa_register_value(arguments[1], dfg);
slice_push_front_operation(
self.brillig_context,
Expand All @@ -618,7 +618,7 @@ impl<'block> BrilligBlock<'block> {

let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

let pop_item = self.function_context.create_register_variable(
self.brillig_context,
Expand All @@ -643,7 +643,7 @@ impl<'block> BrilligBlock<'block> {
);
let target_variable =
self.function_context.create_variable(self.brillig_context, results[1], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

slice_pop_front_operation(
self.brillig_context,
Expand All @@ -659,7 +659,7 @@ impl<'block> BrilligBlock<'block> {
let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);

let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);
slice_insert_operation(
self.brillig_context,
target_vector,
Expand All @@ -674,7 +674,7 @@ impl<'block> BrilligBlock<'block> {

let target_variable =
self.function_context.create_variable(self.brillig_context, results[0], dfg);
let target_vector = self.function_context.extract_heap_vector(target_variable);
let target_vector = self.brillig_context.extract_heap_vector(target_variable);

let removed_item_register = self.function_context.create_register_variable(
self.brillig_context,
Expand Down Expand Up @@ -877,7 +877,7 @@ impl<'block> BrilligBlock<'block> {
Type::Slice(_) => {
let variable =
self.function_context.create_variable(self.brillig_context, result, dfg);
let vector = self.function_context.extract_heap_vector(variable);
let vector = self.brillig_context.extract_heap_vector(variable);

// Set the pointer to the current stack frame
// The stack pointer will then be update by the caller of this method
Expand Down Expand Up @@ -981,8 +981,6 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
BinaryOp::And => BinaryIntOp::And,
BinaryOp::Or => BinaryIntOp::Or,
BinaryOp::Xor => BinaryIntOp::Xor,
BinaryOp::Shl => BinaryIntOp::Shl,
BinaryOp::Shr => BinaryIntOp::Shr,
};

BrilligBinaryOp::Integer { op: operation, bit_size }
Expand Down
7 changes: 0 additions & 7 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,6 @@ impl FunctionContext {
}
}

pub(crate) fn extract_heap_vector(&self, variable: RegisterOrMemory) -> HeapVector {
match variable {
RegisterOrMemory::HeapVector(vector) => vector,
_ => unreachable!("ICE: Expected vector, got {variable:?}"),
}
}

/// Collects the registers that a given variable is stored in.
pub(crate) fn extract_registers(&self, variable: RegisterOrMemory) -> Vec<RegisterIndex> {
match variable {
Expand Down
12 changes: 12 additions & 0 deletions crates/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,18 @@ impl BrilligContext {
self.deallocate_register(end_value_register);
self.deallocate_register(index_at_end_of_array);
}

pub(crate) fn extract_heap_vector(&mut self, variable: RegisterOrMemory) -> HeapVector {
match variable {
RegisterOrMemory::HeapVector(vector) => vector,
RegisterOrMemory::HeapArray(array) => {
let size = self.allocate_register();
self.const_instruction(size, array.size.into());
HeapVector { pointer: array.pointer, size }
}
_ => unreachable!("ICE: Expected vector, got {variable:?}"),
}
}
}

/// Type to encapsulate the binary operation types in Brillig
Expand Down
5 changes: 3 additions & 2 deletions crates/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ impl DebugToString for BinaryIntOp {
BinaryIntOp::And => "&&".into(),
BinaryIntOp::Or => "||".into(),
BinaryIntOp::Xor => "^".into(),
BinaryIntOp::Shl => "<<".into(),
BinaryIntOp::Shr => ">>".into(),
BinaryIntOp::Shl | BinaryIntOp::Shr => {
unreachable!("bit shift should have been replaced")
}
}
}
}
Expand Down
7 changes: 0 additions & 7 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,13 +796,6 @@ impl Context {
bit_count,
self.current_side_effects_enabled_var,
),
BinaryOp::Shl => self.acir_context.shift_left_var(lhs, rhs, binary_type),
BinaryOp::Shr => self.acir_context.shift_right_var(
lhs,
rhs,
binary_type,
self.current_side_effects_enabled_var,
),
BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type),
BinaryOp::And => self.acir_context.and_var(lhs, rhs, binary_type),
BinaryOp::Or => self.acir_context.or_var(lhs, rhs, binary_type),
Expand Down
20 changes: 0 additions & 20 deletions crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -760,16 +760,6 @@ impl Binary {
return SimplifyResult::SimplifiedTo(zero);
}
}
BinaryOp::Shl => {
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
}
}
BinaryOp::Shr => {
if rhs_is_zero {
return SimplifyResult::SimplifiedTo(self.lhs);
}
}
}
SimplifyResult::None
}
Expand Down Expand Up @@ -825,8 +815,6 @@ impl BinaryOp {
BinaryOp::And => None,
BinaryOp::Or => None,
BinaryOp::Xor => None,
BinaryOp::Shl => None,
BinaryOp::Shr => None,
}
}

Expand All @@ -840,8 +828,6 @@ impl BinaryOp {
BinaryOp::And => |x, y| Some(x & y),
BinaryOp::Or => |x, y| Some(x | y),
BinaryOp::Xor => |x, y| Some(x ^ y),
BinaryOp::Shl => |x, y| x.checked_shl(y.try_into().ok()?),
BinaryOp::Shr => |x, y| Some(x >> y),
BinaryOp::Eq => |x, y| Some((x == y) as u128),
BinaryOp::Lt => |x, y| Some((x < y) as u128),
}
Expand Down Expand Up @@ -882,10 +868,6 @@ pub(crate) enum BinaryOp {
Or,
/// Bitwise xor (^)
Xor,
/// Shift lhs left by rhs bits (<<)
Shl,
/// Shift lhs right by rhs bits (>>)
Shr,
}

impl std::fmt::Display for BinaryOp {
Expand All @@ -901,8 +883,6 @@ impl std::fmt::Display for BinaryOp {
BinaryOp::And => write!(f, "and"),
BinaryOp::Or => write!(f, "or"),
BinaryOp::Xor => write!(f, "xor"),
BinaryOp::Shl => write!(f, "shl"),
BinaryOp::Shr => write!(f, "shr"),
}
}
}
Expand Down
77 changes: 61 additions & 16 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use iter_extended::vecmap;
use noirc_errors::Location;
use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters};
use noirc_frontend::monomorphization::ast::{FuncId, Program};
use noirc_frontend::Signedness;
use noirc_frontend::{BinaryOpKind, Signedness};

use crate::ssa_refactor::ir::dfg::DataFlowGraph;
use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId;
use crate::ssa_refactor::ir::function::{Function, RuntimeType};
use crate::ssa_refactor::ir::instruction::BinaryOp;
use crate::ssa_refactor::ir::instruction::{BinaryOp, Endian, Intrinsic};
use crate::ssa_refactor::ir::map::AtomicCounter;
use crate::ssa_refactor::ir::types::{NumericType, Type};
use crate::ssa_refactor::ir::value::ValueId;
Expand Down Expand Up @@ -236,6 +236,46 @@ impl<'a> FunctionContext<'a> {
Values::empty()
}

/// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs
fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let base = self.builder.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
self.builder.insert_binary(lhs, BinaryOp::Mul, pow)
}

/// Insert ssa instructions which computes lhs << rhs by doing lhs/2^rhs
fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let base = self.builder.field_constant(FieldElement::from(2_u128));
let pow = self.pow(base, rhs);
self.builder.insert_binary(lhs, BinaryOp::Div, pow)
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.builder.current_function.dfg.type_of_value(rhs);
if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ {
let to_bits = self.builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little));
let length = self.builder.field_constant(FieldElement::from(bit_size as i128));
let result_types = vec![Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)];
let rhs_bits = self.builder.insert_call(to_bits, vec![rhs, length], result_types)[0];
let one = self.builder.field_constant(FieldElement::one());
let mut r = one;
for i in 1..bit_size + 1 {
let r1 = self.builder.insert_binary(r, BinaryOp::Mul, r);
let a = self.builder.insert_binary(r1, BinaryOp::Mul, lhs);
let idx = self.builder.field_constant(FieldElement::from((bit_size - i) as i128));
let b = self.builder.insert_array_get(rhs_bits, idx, Type::field());
let r2 = self.builder.insert_binary(a, BinaryOp::Mul, b);
let c = self.builder.insert_binary(one, BinaryOp::Sub, b);
let r3 = self.builder.insert_binary(c, BinaryOp::Mul, r1);
r = self.builder.insert_binary(r2, BinaryOp::Add, r3);
}
r
} else {
unreachable!("Value must be unsigned in power operation");
}
}

/// Insert a binary instruction at the end of the current block.
/// Converts the form of the binary instruction as necessary
/// (e.g. swapping arguments, inserting a not) to represent it in the IR.
Expand All @@ -247,17 +287,22 @@ impl<'a> FunctionContext<'a> {
mut rhs: ValueId,
location: Location,
) -> Values {
let op = convert_operator(operator);

if op == BinaryOp::Eq && matches!(self.builder.type_of_value(lhs), Type::Array(..)) {
return self.insert_array_equality(lhs, operator, rhs, location);
}

if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}

let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs);
let mut result = match operator {
BinaryOpKind::ShiftLeft => self.insert_shift_left(lhs, rhs),
BinaryOpKind::ShiftRight => self.insert_shift_right(lhs, rhs),
BinaryOpKind::Equal | BinaryOpKind::NotEqual
if matches!(self.builder.type_of_value(lhs), Type::Array(..)) =>
{
return self.insert_array_equality(lhs, operator, rhs, location)
}
_ => {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
self.builder.set_location(location).insert_binary(lhs, op, rhs)
}
};

if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate(
operator,
Expand Down Expand Up @@ -704,7 +749,6 @@ fn operator_result_max_bit_size_to_truncate(
/// checking operator_requires_not and operator_requires_swapped_operands
/// to represent the full operation correctly.
fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp {
use noirc_frontend::BinaryOpKind;
match op {
BinaryOpKind::Add => BinaryOp::Add,
BinaryOpKind::Subtract => BinaryOp::Sub,
Expand All @@ -720,8 +764,9 @@ fn convert_operator(op: noirc_frontend::BinaryOpKind) -> BinaryOp {
BinaryOpKind::And => BinaryOp::And,
BinaryOpKind::Or => BinaryOp::Or,
BinaryOpKind::Xor => BinaryOp::Xor,
BinaryOpKind::ShiftRight => BinaryOp::Shr,
BinaryOpKind::ShiftLeft => BinaryOp::Shl,
BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft => unreachable!(
"ICE - bit shift operators do not exist in SSA and should have been replaced"
),
}
}

Expand Down
4 changes: 4 additions & 0 deletions crates/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ impl BinaryOpKind {
BinaryOpKind::Modulo => Token::Percent,
}
}

pub fn is_bit_shift(&self) -> bool {
matches!(self, BinaryOpKind::ShiftRight | BinaryOpKind::ShiftLeft)
}
}

#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)]
Expand Down
Loading

0 comments on commit b0fbc53

Please sign in to comment.