Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa gen): ssa gen truncate instruction #1568

Merged
merged 12 commits into from
Jun 9, 2023
Merged
8 changes: 8 additions & 0 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ impl BrilligGen {
};
self.push_code(opcode);
}
Instruction::Truncate { value, .. } => {
// Effectively a no-op because brillig already has implicit truncation on integer
// operations. We need only copy the value to it's destination.
let result_ids = dfg.instruction_results(instruction_id);
let destination = self.get_or_create_register(result_ids[0]);
let source = self.convert_ssa_value(*value, dfg);
self.push_code(BrilligOpcode::Mov { destination, source });
}
_ => todo!("ICE: Instruction not supported {instruction:?}"),
};
}
Expand Down
31 changes: 26 additions & 5 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,9 @@ impl Context {
self.define_result_var(dfg, instruction_id, result_acir_var);
}
Instruction::Truncate { value, bit_size, max_bit_size } => {
let var = self.convert_numeric_value(*value, dfg);

let result_acir_var = self
.acir_context
.truncate_var(var, *bit_size, *max_bit_size)
.convert_ssa_truncate(*value, *bit_size, *max_bit_size, dfg)
.expect("add Result types to all methods so errors bubble up");

self.define_result_var(dfg, instruction_id, result_acir_var);
}
Instruction::ArrayGet { array, index } => {
Expand Down Expand Up @@ -479,6 +475,31 @@ impl Context {
}
}

/// Returns an `AcirVar`that is constrained to be result of the truncation.
fn convert_ssa_truncate(
&mut self,
value_id: ValueId,
bit_size: u32,
max_bit_size: u32,
dfg: &DataFlowGraph,
) -> Result<AcirVar, AcirGenError> {
let mut var = self.convert_numeric_value(value_id, dfg);
let truncation_target = match &dfg[value_id] {
Value::Instruction { instruction, .. } => &dfg[*instruction],
_ => unreachable!("ICE: Truncates are only ever applied to the result of a binary op"),
};
if matches!(truncation_target, Instruction::Binary(Binary { operator: BinaryOp::Sub, .. }))
{
// Subtractions must first have the integer modulus added before truncation can be
// applied. This is done in order to prevent underflow.
let integer_modulus =
self.acir_context.add_constant(FieldElement::from(2_u128.pow(bit_size)));
var = self.acir_context.add_var(var, integer_modulus)?;
}

self.acir_context.truncate_var(var, bit_size, max_bit_size)
}

/// Returns a vector of `AcirVar`s constrained to be result of the function call.
///
/// The function being called is required to be intrinsic.
Expand Down
10 changes: 9 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,15 @@ impl Instruction {
None
}
}
Instruction::Truncate { .. } => None,
Instruction::Truncate { value, bit_size, .. } => {
if let Some((numeric_constant, typ)) = dfg.get_numeric_constant_with_type(*value) {
let integer_modulus = 2_u128.pow(*bit_size);
let truncated = numeric_constant.to_u128() % integer_modulus;
SimplifiedTo(dfg.make_constant(truncated.into(), typ))
} else {
None
}
}
Instruction::Call { .. } => None,
Instruction::Allocate { .. } => None,
Instruction::Load { .. } => None,
Expand Down
12 changes: 12 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ impl FunctionBuilder {
self.insert_instruction(Instruction::Cast(value, typ), None).first()
}

/// Insert a truncate instruction at the end of the current block.
/// Returns the result of the truncate instruction.
pub(crate) fn insert_truncate(
&mut self,
value: ValueId,
bit_size: u32,
max_bit_size: u32,
) -> ValueId {
self.insert_instruction(Instruction::Truncate { value, bit_size, max_bit_size }, None)
.first()
}

/// Insert a constrain instruction at the end of the current block.
pub(crate) fn insert_constrain(&mut self, boolean: ValueId) {
self.insert_instruction(Instruction::Constrain(boolean), None);
Expand Down
79 changes: 78 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Mutex, RwLock};

use acvm::FieldElement;
use iter_extended::vecmap;
use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters};
use noirc_frontend::monomorphization::ast::{FuncId, Program};
use noirc_frontend::Signedness;

use crate::ssa_refactor::ir::dfg::DataFlowGraph;
use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId;
use crate::ssa_refactor::ir::function::{Function, RuntimeType};
use crate::ssa_refactor::ir::instruction::BinaryOp;
use crate::ssa_refactor::ir::map::AtomicCounter;
use crate::ssa_refactor::ir::types::Type;
use crate::ssa_refactor::ir::types::{NumericType, Type};
use crate::ssa_refactor::ir::value::ValueId;
use crate::ssa_refactor::ssa_builder::FunctionBuilder;

Expand Down Expand Up @@ -230,6 +232,23 @@ impl<'a> FunctionContext<'a> {

let mut result = self.builder.insert_binary(lhs, op, rhs);

if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate(
operator,
lhs,
rhs,
&self.builder.current_function.dfg,
) {
let result_type = self.builder.current_function.dfg.type_of_value(result);
let bit_size = match result_type {
Type::Numeric(NumericType::Signed { bit_size })
| Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size,
_ => {
unreachable!("ICE: Truncation attempted on non-integer");
}
};
result = self.builder.insert_truncate(result, bit_size, max_bit_size);
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
Expand Down Expand Up @@ -470,6 +489,64 @@ fn operator_requires_swapped_operands(op: noirc_frontend::BinaryOpKind) -> bool
matches!(op, Greater | LessEqual)
}

/// If the operation requires its result to be truncated because it is an integer, the maximum
/// number of bits that result may occupy is returned.
fn operator_result_max_bit_size_to_truncate(
op: noirc_frontend::BinaryOpKind,
lhs: ValueId,
rhs: ValueId,
dfg: &DataFlowGraph,
) -> Option<u32> {
let lhs_type = dfg.type_of_value(lhs);
let rhs_type = dfg.type_of_value(rhs);

let get_bit_size = |typ| match typ {
Type::Numeric(NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size }) => {
Some(bit_size)
}
_ => None,
};

let lhs_bit_size = get_bit_size(lhs_type)?;
let rhs_bit_size = get_bit_size(rhs_type)?;
use noirc_frontend::BinaryOpKind::*;
match op {
Add => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1),
Subtract => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1),
Multiply => Some(lhs_bit_size + rhs_bit_size),
ShiftLeft => {
if let Some(rhs_constant) = dfg.get_numeric_constant(rhs) {
// Happy case is that we know precisely by how many bits the the integer will
// increase: lhs_bit_size + rhs
return Some(lhs_bit_size + (rhs_constant.to_u128() as u32));
}
// Unhappy case is that we don't yet know the rhs value, (even though it will
// eventually have to resolve to a constant). The best we can is assume the value of
// rhs to be the maximum value of it's numeric type. If that turns out to be larger
// than the native field's bit size, we full back to using that.

// The formula for calculating the max bit size of a left shift is:
// lhs_bit_size + 2^{rhs_bit_size} - 1
// Inferring the max bit size of left shift from its operands can result in huge
// number, that might not only be larger than the native field's max bit size, but
// furthermore might not be representable as a u32. Hence we use overflow checks and
// fallback to the native field's max bits.
let field_max_bits = FieldElement::max_num_bits();
let (rhs_bit_size_pow_2, overflows) = 2_u32.overflowing_pow(rhs_bit_size);
if overflows {
return Some(field_max_bits);
}
let (max_bits_plus_1, overflows) = rhs_bit_size_pow_2.overflowing_add(lhs_bit_size);
if overflows {
return Some(field_max_bits);
}
let max_bit_size = std::cmp::min(max_bits_plus_1 - 1, field_max_bits);
Some(max_bit_size)
}
_ => None,
}
}

/// Converts the given operator to the appropriate BinaryOp.
/// Take care when using this to insert a binary instruction: this requires
/// checking operator_requires_not and operator_requires_swapped_operands
Expand Down