From eb46566c6e8d03b6a1a3f1d20d2b6bb310d7d182 Mon Sep 17 00:00:00 2001 From: joss-aztec <94053499+joss-aztec@users.noreply.github.com> Date: Fri, 9 Jun 2023 19:49:05 +0100 Subject: [PATCH] chore(ssa gen): ssa gen truncate instruction (#1568) * chore(ssa gen): ssa gen truncate instruction * chore(ssa refactor): max bit size for subtract * Update crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs Co-authored-by: jfecher * chore(ssa refactor): truncate shift left * chore(ssa refactor): Add integer modulus when truncating subtraction * chore(ssa refactor): clippy * chore(ssa refactor): fix left shift max bit size * chore(ssa refactor): brillig gen truncate * chore(ssa refactor): truncate const folding * chore(ssa refactor): adaptive left shift max bit size --------- Co-authored-by: kevaundray Co-authored-by: jfecher --- .../src/brillig/brillig_gen.rs | 8 ++ .../src/ssa_refactor/acir_gen/mod.rs | 31 ++++++-- .../src/ssa_refactor/ir/instruction.rs | 10 ++- .../src/ssa_refactor/ssa_builder/mod.rs | 12 +++ .../src/ssa_refactor/ssa_gen/context.rs | 79 ++++++++++++++++++- 5 files changed, 133 insertions(+), 7 deletions(-) diff --git a/crates/noirc_evaluator/src/brillig/brillig_gen.rs b/crates/noirc_evaluator/src/brillig/brillig_gen.rs index 519e341b2b..fde751b16c 100644 --- a/crates/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/crates/noirc_evaluator/src/brillig/brillig_gen.rs @@ -205,6 +205,14 @@ impl BrilligGen { }; self.push_code(opcode); } + Instruction::Truncate { value, .. } => { + // Effectively a no-op because brillig already has implicit truncation on integer + // operations. We need only copy the value to it's destination. + let result_ids = dfg.instruction_results(instruction_id); + let destination = self.get_or_create_register(result_ids[0]); + let source = self.convert_ssa_value(*value, dfg); + self.push_code(BrilligOpcode::Mov { destination, source }); + } _ => todo!("ICE: Instruction not supported {instruction:?}"), }; } diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index 12c1fde04c..4148ee3108 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -212,13 +212,9 @@ impl Context { self.define_result_var(dfg, instruction_id, result_acir_var); } Instruction::Truncate { value, bit_size, max_bit_size } => { - let var = self.convert_numeric_value(*value, dfg); - let result_acir_var = self - .acir_context - .truncate_var(var, *bit_size, *max_bit_size) + .convert_ssa_truncate(*value, *bit_size, *max_bit_size, dfg) .expect("add Result types to all methods so errors bubble up"); - self.define_result_var(dfg, instruction_id, result_acir_var); } Instruction::ArrayGet { array, index } => { @@ -479,6 +475,31 @@ impl Context { } } + /// Returns an `AcirVar`that is constrained to be result of the truncation. + fn convert_ssa_truncate( + &mut self, + value_id: ValueId, + bit_size: u32, + max_bit_size: u32, + dfg: &DataFlowGraph, + ) -> Result { + let mut var = self.convert_numeric_value(value_id, dfg); + let truncation_target = match &dfg[value_id] { + Value::Instruction { instruction, .. } => &dfg[*instruction], + _ => unreachable!("ICE: Truncates are only ever applied to the result of a binary op"), + }; + if matches!(truncation_target, Instruction::Binary(Binary { operator: BinaryOp::Sub, .. })) + { + // Subtractions must first have the integer modulus added before truncation can be + // applied. This is done in order to prevent underflow. + let integer_modulus = + self.acir_context.add_constant(FieldElement::from(2_u128.pow(bit_size))); + var = self.acir_context.add_var(var, integer_modulus)?; + } + + self.acir_context.truncate_var(var, bit_size, max_bit_size) + } + /// Returns a vector of `AcirVar`s constrained to be result of the function call. /// /// The function being called is required to be intrinsic. diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs index a9767bc377..d43e22df4d 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/instruction.rs @@ -243,7 +243,15 @@ impl Instruction { None } } - Instruction::Truncate { .. } => None, + Instruction::Truncate { value, bit_size, .. } => { + if let Some((numeric_constant, typ)) = dfg.get_numeric_constant_with_type(*value) { + let integer_modulus = 2_u128.pow(*bit_size); + let truncated = numeric_constant.to_u128() % integer_modulus; + SimplifiedTo(dfg.make_constant(truncated.into(), typ)) + } else { + None + } + } Instruction::Call { .. } => None, Instruction::Allocate { .. } => None, Instruction::Load { .. } => None, diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs index 902278e286..49e68dad1f 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_builder/mod.rs @@ -212,6 +212,18 @@ impl FunctionBuilder { self.insert_instruction(Instruction::Cast(value, typ), None).first() } + /// Insert a truncate instruction at the end of the current block. + /// Returns the result of the truncate instruction. + pub(crate) fn insert_truncate( + &mut self, + value: ValueId, + bit_size: u32, + max_bit_size: u32, + ) -> ValueId { + self.insert_instruction(Instruction::Truncate { value, bit_size, max_bit_size }, None) + .first() + } + /// Insert a constrain instruction at the end of the current block. pub(crate) fn insert_constrain(&mut self, boolean: ValueId) { self.insert_instruction(Instruction::Constrain(boolean), None); diff --git a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs index e20a54ba8c..3f8b6f3885 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs @@ -2,16 +2,18 @@ use std::collections::HashMap; use std::rc::Rc; use std::sync::{Mutex, RwLock}; +use acvm::FieldElement; use iter_extended::vecmap; use noirc_frontend::monomorphization::ast::{self, LocalId, Parameters}; use noirc_frontend::monomorphization::ast::{FuncId, Program}; use noirc_frontend::Signedness; +use crate::ssa_refactor::ir::dfg::DataFlowGraph; use crate::ssa_refactor::ir::function::FunctionId as IrFunctionId; use crate::ssa_refactor::ir::function::{Function, RuntimeType}; use crate::ssa_refactor::ir::instruction::BinaryOp; use crate::ssa_refactor::ir::map::AtomicCounter; -use crate::ssa_refactor::ir::types::Type; +use crate::ssa_refactor::ir::types::{NumericType, Type}; use crate::ssa_refactor::ir::value::ValueId; use crate::ssa_refactor::ssa_builder::FunctionBuilder; @@ -230,6 +232,23 @@ impl<'a> FunctionContext<'a> { let mut result = self.builder.insert_binary(lhs, op, rhs); + if let Some(max_bit_size) = operator_result_max_bit_size_to_truncate( + operator, + lhs, + rhs, + &self.builder.current_function.dfg, + ) { + let result_type = self.builder.current_function.dfg.type_of_value(result); + let bit_size = match result_type { + Type::Numeric(NumericType::Signed { bit_size }) + | Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size, + _ => { + unreachable!("ICE: Truncation attempted on non-integer"); + } + }; + result = self.builder.insert_truncate(result, bit_size, max_bit_size); + } + if operator_requires_not(operator) { result = self.builder.insert_not(result); } @@ -470,6 +489,64 @@ fn operator_requires_swapped_operands(op: noirc_frontend::BinaryOpKind) -> bool matches!(op, Greater | LessEqual) } +/// If the operation requires its result to be truncated because it is an integer, the maximum +/// number of bits that result may occupy is returned. +fn operator_result_max_bit_size_to_truncate( + op: noirc_frontend::BinaryOpKind, + lhs: ValueId, + rhs: ValueId, + dfg: &DataFlowGraph, +) -> Option { + let lhs_type = dfg.type_of_value(lhs); + let rhs_type = dfg.type_of_value(rhs); + + let get_bit_size = |typ| match typ { + Type::Numeric(NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size }) => { + Some(bit_size) + } + _ => None, + }; + + let lhs_bit_size = get_bit_size(lhs_type)?; + let rhs_bit_size = get_bit_size(rhs_type)?; + use noirc_frontend::BinaryOpKind::*; + match op { + Add => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1), + Subtract => Some(std::cmp::max(lhs_bit_size, rhs_bit_size) + 1), + Multiply => Some(lhs_bit_size + rhs_bit_size), + ShiftLeft => { + if let Some(rhs_constant) = dfg.get_numeric_constant(rhs) { + // Happy case is that we know precisely by how many bits the the integer will + // increase: lhs_bit_size + rhs + return Some(lhs_bit_size + (rhs_constant.to_u128() as u32)); + } + // Unhappy case is that we don't yet know the rhs value, (even though it will + // eventually have to resolve to a constant). The best we can is assume the value of + // rhs to be the maximum value of it's numeric type. If that turns out to be larger + // than the native field's bit size, we full back to using that. + + // The formula for calculating the max bit size of a left shift is: + // lhs_bit_size + 2^{rhs_bit_size} - 1 + // Inferring the max bit size of left shift from its operands can result in huge + // number, that might not only be larger than the native field's max bit size, but + // furthermore might not be representable as a u32. Hence we use overflow checks and + // fallback to the native field's max bits. + let field_max_bits = FieldElement::max_num_bits(); + let (rhs_bit_size_pow_2, overflows) = 2_u32.overflowing_pow(rhs_bit_size); + if overflows { + return Some(field_max_bits); + } + let (max_bits_plus_1, overflows) = rhs_bit_size_pow_2.overflowing_add(lhs_bit_size); + if overflows { + return Some(field_max_bits); + } + let max_bit_size = std::cmp::min(max_bits_plus_1 - 1, field_max_bits); + Some(max_bit_size) + } + _ => None, + } +} + /// Converts the given operator to the appropriate BinaryOp. /// Take care when using this to insert a binary instruction: this requires /// checking operator_requires_not and operator_requires_swapped_operands