Skip to content

Commit

Permalink
feat: add support for u1 in the avm circuit & witgen
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanks12 committed Sep 17, 2024
1 parent 07e6a7e commit 62c53fc
Show file tree
Hide file tree
Showing 59 changed files with 2,106 additions and 1,506 deletions.
10 changes: 5 additions & 5 deletions avm-transpiler/src/bit_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ impl BitsQueryable for usize {

pub fn bits_needed_for<T: BitsQueryable>(val: &T) -> usize {
let num_bits = val.num_bits();
if num_bits < 8 {
if num_bits <= 8 {
8
} else if num_bits < 16 {
} else if num_bits <= 16 {
16
} else if num_bits < 32 {
} else if num_bits <= 32 {
32
} else if num_bits < 64 {
} else if num_bits <= 64 {
64
} else if num_bits < 128 {
} else if num_bits <= 128 {
128
} else {
254
Expand Down
4 changes: 4 additions & 0 deletions avm-transpiler/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ impl Default for AvmInstruction {
#[derive(Copy, Clone, Debug)]
pub enum AvmTypeTag {
UNINITIALIZED,
UINT1,
UINT8,
UINT16,
UINT32,
Expand All @@ -107,6 +108,7 @@ pub enum AvmTypeTag {
/// Constants (as used by the SET instruction) can have size
/// different from 32 bits
pub enum AvmOperand {
U1 { value: u8 }, // same wire format as U8
U8 { value: u8 },
U16 { value: u16 },
U32 { value: u32 },
Expand All @@ -118,6 +120,7 @@ pub enum AvmOperand {
impl Display for AvmOperand {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
AvmOperand::U1 { value } => write!(f, " U1:{}", value),
AvmOperand::U8 { value } => write!(f, " U8:{}", value),
AvmOperand::U16 { value } => write!(f, " U16:{}", value),
AvmOperand::U32 { value } => write!(f, " U32:{}", value),
Expand All @@ -131,6 +134,7 @@ impl Display for AvmOperand {
impl AvmOperand {
pub fn to_be_bytes(&self) -> Vec<u8> {
match self {
AvmOperand::U1 { value } => value.to_be_bytes().to_vec(),
AvmOperand::U8 { value } => value.to_be_bytes().to_vec(),
AvmOperand::U16 { value } => value.to_be_bytes().to_vec(),
AvmOperand::U32 { value } => value.to_be_bytes().to_vec(),
Expand Down
49 changes: 7 additions & 42 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,6 @@ pub fn brillig_to_avm(
],
tag: Some(tag_from_bit_size(BitSize::Integer(*bit_size))),
});
if let IntegerBitSize::U1 = bit_size {
// We need to cast the result back to u1
handle_cast(
&mut avm_instrs,
destination,
destination,
BitSize::Integer(IntegerBitSize::U1),
);
}
}
BrilligOpcode::CalldataCopy { destination_address, size_address, offset_address } => {
avm_instrs.push(AvmInstruction {
Expand Down Expand Up @@ -505,32 +496,8 @@ fn handle_cast(
let source_offset = source.to_usize() as u32;
let dest_offset = destination.to_usize() as u32;

if bit_size == BitSize::Integer(IntegerBitSize::U1) {
assert!(
matches!(tag_from_bit_size(bit_size), AvmTypeTag::UINT8),
"If u1 doesn't map to u8 anymore, change this code!"
);
avm_instrs.extend([
// We cast to Field to be able to use toradix.
generate_cast_instruction(source_offset, false, dest_offset, false, AvmTypeTag::FIELD),
// Toradix with radix 2 and 1 limb is the same as modulo 2.
// We need to insert an instruction explicitly because we want to fine-tune 'indirect'.
AvmInstruction {
opcode: AvmOpcode::TORADIXLE,
indirect: Some(ALL_DIRECT),
tag: None,
operands: vec![
AvmOperand::U32 { value: dest_offset },
AvmOperand::U32 { value: dest_offset },
AvmOperand::U32 { value: /*radix=*/ 2},
AvmOperand::U32 { value: /*limbs=*/ 1},
],
},
]);
} else {
let tag = tag_from_bit_size(bit_size);
avm_instrs.push(generate_cast_instruction(source_offset, false, dest_offset, false, tag));
}
let tag = tag_from_bit_size(bit_size);
avm_instrs.push(generate_cast_instruction(source_offset, false, dest_offset, false, tag));
}

/// Handle an AVM NOTEHASHEXISTS instruction
Expand Down Expand Up @@ -987,12 +954,11 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
..Default::default()
});
}
// We ignore the output bits flag since we represent bits as bytes
BlackBoxOp::ToRadix { input, radix, output, output_bits: _ } => {
BlackBoxOp::ToRadix { input, radix, output, output_bits } => {
let num_limbs = output.size as u32;
let input_offset = input.0 as u32;
let output_offset = output.pointer.0 as u32;
assert!(radix <= &256u32, "Radix must be less than or equal to 256");
let radix_offset = radix.0 as u32;

avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::TORADIXLE,
Expand All @@ -1001,8 +967,9 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
operands: vec![
AvmOperand::U32 { value: input_offset },
AvmOperand::U32 { value: output_offset },
AvmOperand::U32 { value: *radix },
AvmOperand::U32 { value: radix_offset },
AvmOperand::U32 { value: num_limbs },
AvmOperand::U1 { value: *output_bits as u8 },
],
});
}
Expand Down Expand Up @@ -1313,8 +1280,6 @@ pub fn map_brillig_pcs_to_avm_pcs(brillig_bytecode: &[BrilligOpcode<FieldElement
pc_map[0] = 0; // first PC is always 0 as there are no instructions inserted by AVM at start
for i in 0..brillig_bytecode.len() - 1 {
let num_avm_instrs_for_this_brillig_instr = match &brillig_bytecode[i] {
BrilligOpcode::Cast { bit_size: BitSize::Integer(IntegerBitSize::U1), .. } => 2,
BrilligOpcode::Not { bit_size: IntegerBitSize::U1, .. } => 3,
_ => 1,
};
// next Brillig pc will map to an AVM pc offset by the
Expand All @@ -1338,7 +1303,7 @@ fn is_integral_bit_size(bit_size: IntegerBitSize) -> bool {

fn tag_from_bit_size(bit_size: BitSize) -> AvmTypeTag {
match bit_size {
BitSize::Integer(IntegerBitSize::U1) => AvmTypeTag::UINT8, // temp workaround
BitSize::Integer(IntegerBitSize::U1) => AvmTypeTag::UINT1,
BitSize::Integer(IntegerBitSize::U8) => AvmTypeTag::UINT8,
BitSize::Integer(IntegerBitSize::U16) => AvmTypeTag::UINT16,
BitSize::Integer(IntegerBitSize::U32) => AvmTypeTag::UINT32,
Expand Down
30 changes: 21 additions & 9 deletions barretenberg/cpp/pil/avm/alu.pil
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "constants_gen.pil";
include "gadgets/range_check.pil";
include "gadgets/cmp.pil";
namespace alu(256);
Expand All @@ -12,16 +13,17 @@ namespace alu(256);
pol commit ic;
pol commit sel_alu; // Predicate to activate the copy of intermediate registers to ALU table.

// Instruction tag (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field) copied from Main table
// Instruction tag (1: u1, 2: u8, 3: u16, 4: u32, 5: u64, 6: u128, 7: field) copied from Main table
pol commit in_tag;

// Flattened boolean instruction tags
pol commit ff_tag;
pol commit u1_tag;
pol commit u8_tag;
pol commit u16_tag;
pol commit u32_tag;
pol commit u64_tag;
pol commit u128_tag;
pol commit ff_tag;

// Compute predicate telling whether there is a row entry in the ALU table.
sel_alu = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte + op_shr + op_shl + op_div;
Expand All @@ -30,18 +32,25 @@ namespace alu(256);
// Remark: Operation selectors are constrained in the main trace.

// Boolean flattened instructions tags
ff_tag * (1 - ff_tag) = 0;
u1_tag * (1 - u1_tag) = 0;
u8_tag * (1 - u8_tag) = 0;
u16_tag * (1 - u16_tag) = 0;
u32_tag * (1 - u32_tag) = 0;
u64_tag * (1 - u64_tag) = 0;
u128_tag * (1 - u128_tag) = 0;
ff_tag * (1 - ff_tag) = 0;

// Mutual exclusion of the flattened instruction tag.
sel_alu * (ff_tag + u8_tag + u16_tag + u32_tag + u64_tag + u128_tag - 1) = 0;
sel_alu * (u1_tag + u8_tag + u16_tag + u32_tag + u64_tag + u128_tag + ff_tag - 1) = 0;

// Correct flattening of the instruction tag.
in_tag = u8_tag + 2 * u16_tag + 3 * u32_tag + 4 * u64_tag + 5 * u128_tag + 6 * ff_tag;
in_tag = u1_tag // * 1 (constants.MEM_TAG_U1)
+ (constants.MEM_TAG_U8 * u8_tag)
+ (constants.MEM_TAG_U16 * u16_tag)
+ (constants.MEM_TAG_U32 * u32_tag)
+ (constants.MEM_TAG_U64 * u64_tag)
+ (constants.MEM_TAG_U128 * u128_tag)
+ (constants.MEM_TAG_FF * ff_tag);

// Operation selectors are copied from main table and do not need to be constrained here.
// Mutual exclusion of op_add and op_sub are derived from their mutual exclusion in the
Expand All @@ -59,7 +68,7 @@ namespace alu(256);
range_check_input_value = (op_add + op_sub + op_mul + op_cast + op_div) * ic + (op_shr * a_hi * NON_TRIVIAL_SHIFT) + (op_shl * a_lo * NON_TRIVIAL_SHIFT);
// The allowed bit range is defined by the instr tag, unless in shifts where it's different
range_check_num_bits =
(op_add + op_sub + op_mul + op_cast + op_div) * (u8_tag * 8 + u16_tag * 16 + u32_tag * 32 + u64_tag * 64 + u128_tag * 128) +
(op_add + op_sub + op_mul + op_cast + op_div) * (u1_tag * 1 + u8_tag * 8 + u16_tag * 16 + u32_tag * 32 + u64_tag * 64 + u128_tag * 128) +
(op_shl + op_shr) * (MAX_BITS - ib) * NON_TRIVIAL_SHIFT;

// Permutation to the Range Check Gadget
Expand Down Expand Up @@ -92,15 +101,16 @@ namespace alu(256);
// These are useful and commonly used relations / columns used through the file

// The maximum number of bits as defined by the instr tag
pol MAX_BITS = u8_tag * 8 + u16_tag * 16 + u32_tag * 32 + u64_tag * 64 + u128_tag * 128;
pol MAX_BITS = u1_tag * 1 + u8_tag * 8 + u16_tag * 16 + u32_tag * 32 + u64_tag * 64 + u128_tag * 128;
// 2^MAX_BITS
pol MAX_BITS_POW = u8_tag * 2**8 + u16_tag * 2**16 + u32_tag * 2**32 + u64_tag * 2**64 + u128_tag * 2**128;
pol MAX_BITS_POW = u1_tag * 2 + u8_tag * 2**8 + u16_tag * 2**16 + u32_tag * 2**32 + u64_tag * 2**64 + u128_tag * 2**128;
pol UINT_MAX = MAX_BITS_POW - 1;

// Value of p - 1
pol MAX_FIELD_VALUE = 21888242871839275222246405745257275088548364400416034343698204186575808495616;

// Used when we split inputs into lo and hi limbs each of (MAX_BITS / 2)
// omitted: u1_tag * 0 (no need for limbs...)
pol LIMB_BITS_POW = u8_tag * 2**4 + u16_tag * 2**8 + u32_tag * 2**16 + u64_tag * 2**32 + u128_tag * 2**64;
// Lo and Hi Limbs for ia, ib and ic resp. Useful when performing operations over integers
pol commit a_lo;
Expand Down Expand Up @@ -132,7 +142,8 @@ namespace alu(256);
a_lo * b_hi + b_lo * a_hi = partial_prod_lo + LIMB_BITS_POW * partial_prod_hi;

// This holds the product over the integers
pol PRODUCT = a_lo * b_lo + LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi);
// (u1 multiplication only cares about a_lo and b_lo)
pol PRODUCT = a_lo * b_lo + (1 - u1_tag) * (LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi));

// =============== ADDITION/SUBTRACTION Operation Constraints =================================================
pol commit op_add;
Expand Down Expand Up @@ -243,6 +254,7 @@ namespace alu(256);

// =============== Trivial Shift Operation =================================================
// We use the comparison gadget to test ib > (MAX_BITS - 1)
// (always true for u1 - all u1 shifts are trivial)
(op_shl + op_shr) * (cmp_gadget_input_a - ib) = 0;
(op_shl + op_shr) * (cmp_gadget_input_b - (MAX_BITS - 1) ) = 0;

Expand Down
4 changes: 2 additions & 2 deletions barretenberg/cpp/pil/avm/binary.pil
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace binary(256);
pol commit acc_ib;
pol commit acc_ic;

// This is the instruction tag {1,2,3,4,5} (restricted to not be a field)
// This is the instruction tag {1,2,3,4,5,6} (restricted to not be a field)
// Operations over FF are not supported, it is assumed this exclusion is handled
// outside of this subtrace.

Expand All @@ -37,7 +37,7 @@ namespace binary(256);

// To support dynamically sized memory operands we use a counter against a lookup
// This decrementing counter goes from [MEM_TAG, 0] where MEM_TAG is the number of bytes in the
// corresponding integer. i.e. MEM_TAG is between 1 (U8) and 16(U128).
// corresponding integer. i.e. MEM_TAG is between 1 (U8) and 16 (U128).
// Consistency can be achieved with a lookup table between the instr_tag and bytes_length
pol commit mem_tag_ctr;
#[MEM_TAG_REL]
Expand Down
7 changes: 7 additions & 0 deletions barretenberg/cpp/pil/avm/constants_gen.pil
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ namespace constants(256);
pol MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL = 16;
pol MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL = 16;
pol MAX_UNENCRYPTED_LOGS_PER_CALL = 4;
pol MEM_TAG_U1 = 1;
pol MEM_TAG_U8 = 2;
pol MEM_TAG_U16 = 3;
pol MEM_TAG_U32 = 4;
pol MEM_TAG_U64 = 5;
pol MEM_TAG_U128 = 6;
pol MEM_TAG_FF = 7;
pol SENDER_SELECTOR = 0;
pol ADDRESS_SELECTOR = 1;
pol STORAGE_ADDRESS_SELECTOR = 1;
Expand Down
6 changes: 3 additions & 3 deletions barretenberg/cpp/pil/avm/fixed/byte_lookup.pil
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ namespace byte_lookup(256);
pol constant sel_bin;

// These two columns are a mapping between instruction tags and their byte lengths
// {U8: 1, U16: 2, ... , U128: 16}
pol constant table_in_tags; // Column of U8,U16,...,U128
pol constant table_byte_lengths; // Columns of byte lengths 1,2,...,16;
// {U1: 1, U8: 1, U16: 2, ... , U128: 16}
pol constant table_in_tags; // Column of U1,U8,U16,...,U128
pol constant table_byte_lengths; // Columns of byte lengths 1,1,2,...,16;
1 change: 1 addition & 0 deletions barretenberg/cpp/pil/avm/gadgets/conversion.pil
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ namespace conversion(256);
pol commit input;
pol commit radix;
pol commit num_limbs;
pol commit output_bits;
22 changes: 11 additions & 11 deletions barretenberg/cpp/pil/avm/main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ namespace main(256);
// Helper selector to characterize a Binary chiplet selector
pol commit sel_bin;

// Instruction memory tags read/write (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field)
// Instruction memory tags read/write (1: u1, 2: u8, 3: u16, 4: u32, 5: u64, 6: u128, 7: field)
pol commit r_in_tag;
pol commit w_in_tag;
pol commit alu_in_tag; // Copy of r_in_tag or w_in_tag depending of the operation. It is sent to ALU trace.
Expand Down Expand Up @@ -311,17 +311,17 @@ namespace main(256);
// values should be written into these memory indices.
// - For indirect memory accesses, the memory trace constraints ensure that
// loaded values come from memory addresses with tag u32. This is enforced in the memory trace
// where each memory entry with flag sel_resolve_ind_addr_x (for x = a,b,c,d) constrains r_int_tag == 3 (u32).
// where each memory entry with flag sel_resolve_ind_addr_x (for x = a,b,c,d) constrains r_int_tag is u32.
//
// - ind_addr_a, ind_addr_b, ind_addr_c, ind_addr_d to u32 type: Should be guaranteed by bytecode validation and
// instruction decomposition as only immediate 32-bit values should be written into the indirect registers.
//
// - 0 <= r_in_tag, w_in_tag <= 6 // This should be constrained by the operation decomposition.
// - 0 <= r_in_tag, w_in_tag <= constants.MEM_TAG_FF // This should be constrained by the operation decomposition.

//====== COMPARATOR OPCODES CONSTRAINTS =====================================
// Enforce that the tag for the ouput of EQ opcode is u8 (i.e. equal to 1).
#[OUTPUT_U8]
(sel_op_eq + sel_op_lte + sel_op_lt) * (w_in_tag - 1) = 0;
// Enforce that the tag for the ouput of EQ opcode is u1 (i.e. equal to 1).
#[OUTPUT_U1]
(sel_op_eq + sel_op_lte + sel_op_lt) * (w_in_tag - constants.MEM_TAG_U1) = 0;

//====== FDIV OPCODE CONSTRAINTS ============================================
// Relation for division over the finite field
Expand All @@ -340,13 +340,13 @@ namespace main(256);
#[SUBOP_FDIV_ZERO_ERR2]
(sel_op_fdiv + sel_op_div) * op_err * (1 - inv) = 0;

// Enforcement that instruction tags are FF (tag constant 6).
// Enforcement that instruction tags are FF
// TODO: These 2 conditions might be removed and enforced through
// the bytecode decomposition instead.
#[SUBOP_FDIV_R_IN_TAG_FF]
sel_op_fdiv * (r_in_tag - 6) = 0;
sel_op_fdiv * (r_in_tag - constants.MEM_TAG_FF) = 0;
#[SUBOP_FDIV_W_IN_TAG_FF]
sel_op_fdiv * (w_in_tag - 6) = 0;
sel_op_fdiv * (w_in_tag - constants.MEM_TAG_FF) = 0;

// op_err cannot be maliciously activated for a non-relevant
// operation selector, i.e., op_err == 1 ==> sel_op_fdiv || sel_op_XXX || ...
Expand Down Expand Up @@ -542,9 +542,9 @@ namespace main(256);
binary.start {binary.clk, binary.acc_ia, binary.acc_ib, binary.acc_ic, binary.op_id, binary.in_tag};

#[PERM_MAIN_CONV]
sel_op_radix_le {clk, ia, ic, id}
sel_op_radix_le {clk, ia, ib, ic, id}
is
conversion.sel_to_radix_le {conversion.clk, conversion.input, conversion.radix, conversion.num_limbs};
conversion.sel_to_radix_le {conversion.clk, conversion.input, conversion.radix, conversion.num_limbs, conversion.output_bits};

// This will be enabled when we migrate just to sha256Compression, as getting sha256 to work with it is tricky.
// #[PERM_MAIN_SHA256]
Expand Down
Loading

0 comments on commit 62c53fc

Please sign in to comment.