Skip to content

Commit

Permalink
chore(ssa gen): ssa gen truncate instruction (#1568)
Browse files Browse the repository at this point in the history
* chore(ssa gen): ssa gen truncate instruction

* chore(ssa refactor): max bit size for subtract

* Update crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs

Co-authored-by: jfecher <jake@aztecprotocol.com>

* chore(ssa refactor): truncate shift left

* chore(ssa refactor): Add integer modulus when truncating subtraction

* chore(ssa refactor): clippy

* chore(ssa refactor): fix left shift max bit size

* chore(ssa refactor): brillig gen truncate

* chore(ssa refactor): truncate const folding

* chore(ssa refactor): adaptive left shift max bit size

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
3 people authored Jun 9, 2023
1 parent 76a3def commit eb46566
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 7 deletions.
8 changes: 8 additions & 0 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ impl BrilligGen {
};
self.push_code(opcode);
}
Instruction::Truncate { value, .. } => {
// Effectively a no-op because brillig already has implicit truncation on integer
// operations. We need only copy the value to it's destination.
let result_ids = dfg.instruction_results(instruction_id);
let destination = self.get_or_create_register(result_ids[0]);
let source = self.convert_ssa_value(*value, dfg);
self.push_code(BrilligOpcode::Mov { destination, source });
}
_ => todo!("ICE: Instruction not supported {instruction:?}"),
};
}
Expand Down
31 changes: 26 additions & 5 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,9 @@ impl Context {
self.define_result_var(dfg, instruction_id, result_acir_var);
}
Instruction::Truncate { value, bit_size, max_bit_size } => {
let var = self.convert_numeric_value(*value, dfg);

let result_acir_var = self
.acir_context
.truncate_var(var, *bit_size, *max_bit_size)
.convert_ssa_truncate(*value, *bit_size, *max_bit_size, dfg)
.expect("add Result types to all methods so errors bubble up");

self.define_result_var(dfg, instruction_id, result_acir_var);
}
Instruction::ArrayGet { array, index } => {
Expand Down Expand Up @@ -479,6 +475,31 @@ impl Context {
}
}

/// Returns an `AcirVar`that is constrained to be result of the truncation.
fn convert_ssa_truncate(
&mut self,
value_id: ValueId,
bit_size: u32,
max_bit_size: u32,
dfg: &DataFlowGraph,
) -> Result<AcirVar, AcirGenError> {
let mut var = self.convert_numeric_value(value_id, dfg);
let truncation_target = match &dfg[value_id] {
Value::Instruction { instruction, .. } => &dfg[*instruction],
_ => unreachable!("ICE: Truncates are only ever applied to the result of a binary op"),
};
if matches!(truncation_target, Instruction::Binary(Binary { operator: BinaryOp::Sub, .. }))
{
// Subtractions must first have the integer modulus added before truncation can be
// applied. This is done in order to prevent underflow.
let integer_modulus =
self.acir_context.add_constant(FieldElement::from(2_u128.pow(bit_size)));
var = self.acir_context.add_var(var, integer_modulus)?;
}

self.acir_context.truncate_var(var, bit_size, max_bit_size)
}

/// Returns a vector of `AcirVar`s constrained to be result of the function call.
///
/// The function being called is required to be intrinsic.
Expand Down
10 changes: 9 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,15 @@ impl Instruction {
None
}
}
Instruction::Truncate { .. } => None,
Instruction::Truncate { value, bit_size, .. } => {
if let Some((numeric_constant, typ)) = dfg.get_numeric_constant_with_type(*value) {
let integer_modulus = 2_u128.pow(*bit_size);
let truncated = numeric_constant.to_u128() % integer_modulus;
SimplifiedTo(dfg.make_constant(truncated.into(), typ))
} else {
None
}
}
Instruction::Call { .. } => None,
Instruction::Allocate { .. } => None,
Instruction::Load { .. } => None,
Expand Down
12 changes: 12 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ impl FunctionBuilder {
self.insert_instruction(Instruction::Cast(value, typ), None).first()
}

/// Insert a truncate instruction at the end of the current block.
/// Returns the result of the truncate instruction.
pub(crate) fn insert_truncate(
&mut self,
value: ValueId,
bit_size: u32,
max_bit_size: u32,
) -> ValueId {
self.insert_instruction(Instruction::Truncate { value, bit_size, max_bit_size }, None)
.first()
}

/// Insert a constrain instruction at the end of the current block.
pub(crate) fn insert_constrain(&mut self, boolean: ValueId) {
self.insert_instruction(Instruction::Constrain(boolean), None);
Expand Down
79 changes: 78 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Mutex, RwLock};

use acvm::FieldElement;
use iter_extended::vecmap;
use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters};
use noirc_frontend::monomorphization::ast::{FuncId, Program};
use noirc_frontend::Signedness;

use crate::ssa_refactor::ir::dfg::DataFlowGraph;
use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId;
use crate::ssa_refactor::ir::function::{Function, RuntimeType};
use crate::ssa_refactor::ir::instruction::BinaryOp;
use crate::ssa_refactor::ir::map::AtomicCounter;
use crate::ssa_refactor::ir::types::Type;
use crate::ssa_refactor::ir::types::{NumericType, Type};
use crate::ssa_refactor::ir::value::ValueId;
use crate::ssa_refactor::ssa_builder::FunctionBuilder;

Expand Down Expand Up @@ -230,6 +232,23 @@ impl<'a> FunctionContext<'a> {

let mut result = self.builder.insert_binary(lhs, op, rhs);

if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate(
operator,
lhs,
rhs,
&self.builder.current_function.dfg,
) {
let result_type = self.builder.current_function.dfg.type_of_value(result);
let bit_size = match result_type {
Type::Numeric(NumericType::Signed { bit_size })
| Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size,
_ => {
unreachable!("ICE: Truncation attempted on non-integer");
}
};
result = self.builder.insert_truncate(result, bit_size, max_bit_size);
}

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
Expand Down Expand Up @@ -470,6 +489,64 @@ fn operator_requires_swapped_operands(op: noirc_frontend::BinaryOpKind) -> bool
matches!(op, Greater | LessEqual)
}

/// If the operation requires its result to be truncated because it is an integer, the maximum
/// number of bits that result may occupy is returned.
fn operator_result_max_bit_size_to_truncate(
op: noirc_frontend::BinaryOpKind,
lhs: ValueId,
rhs: ValueId,
dfg: &DataFlowGraph,
) -> Option<u32> {
let lhs_type = dfg.type_of_value(lhs);
let rhs_type = dfg.type_of_value(rhs);

let get_bit_size = |typ| match typ {
Type::Numeric(NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size }) => {
Some(bit_size)
}
_ => None,
};

let lhs_bit_size = get_bit_size(lhs_type)?;
let rhs_bit_size = get_bit_size(rhs_type)?;
use noirc_frontend::BinaryOpKind::*;
match op {
Add => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1),
Subtract => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1),
Multiply => Some(lhs_bit_size + rhs_bit_size),
ShiftLeft => {
if let Some(rhs_constant) = dfg.get_numeric_constant(rhs) {
// Happy case is that we know precisely by how many bits the the integer will
// increase: lhs_bit_size + rhs
return Some(lhs_bit_size + (rhs_constant.to_u128() as u32));
}
// Unhappy case is that we don't yet know the rhs value, (even though it will
// eventually have to resolve to a constant). The best we can is assume the value of
// rhs to be the maximum value of it's numeric type. If that turns out to be larger
// than the native field's bit size, we full back to using that.

// The formula for calculating the max bit size of a left shift is:
// lhs_bit_size + 2^{rhs_bit_size} - 1
// Inferring the max bit size of left shift from its operands can result in huge
// number, that might not only be larger than the native field's max bit size, but
// furthermore might not be representable as a u32. Hence we use overflow checks and
// fallback to the native field's max bits.
let field_max_bits = FieldElement::max_num_bits();
let (rhs_bit_size_pow_2, overflows) = 2_u32.overflowing_pow(rhs_bit_size);
if overflows {
return Some(field_max_bits);
}
let (max_bits_plus_1, overflows) = rhs_bit_size_pow_2.overflowing_add(lhs_bit_size);
if overflows {
return Some(field_max_bits);
}
let max_bit_size = std::cmp::min(max_bits_plus_1 - 1, field_max_bits);
Some(max_bit_size)
}
_ => None,
}
}

/// Converts the given operator to the appropriate BinaryOp.
/// Take care when using this to insert a binary instruction: this requires
/// checking operator_requires_not and operator_requires_swapped_operands
Expand Down

0 comments on commit eb46566

Please sign in to comment.