From b3b372c8ca2216fa32337d67954ccf5ed02bc03a Mon Sep 17 00:00:00 2001 From: Tej Qu Nair Date: Wed, 28 Aug 2024 11:13:05 -0700 Subject: [PATCH] perf: basic constant propagation for `Imm` instructions in v2 compiler (#1421) --- .../compiler/src/circuit/compiler.rs | 239 +++++++++--------- .../core-v2/src/chips/mem/constant.rs | 2 +- 2 files changed, 125 insertions(+), 116 deletions(-) diff --git a/crates/recursion/compiler/src/circuit/compiler.rs b/crates/recursion/compiler/src/circuit/compiler.rs index 1d41dcb9af..c83f973d84 100644 --- a/crates/recursion/compiler/src/circuit/compiler.rs +++ b/crates/recursion/compiler/src/circuit/compiler.rs @@ -27,7 +27,7 @@ pub struct AsmCompiler { /// Map the frame pointers of the variables to the "physical" addresses. pub virtual_to_physical: VecMap>, /// Map base or extension field constants to "physical" addresses and mults. - pub consts: BTreeMap, (Address, C::F)>, + pub consts: BTreeMap, Address>, /// Map each "physical" address to its read count. pub addr_to_mult: VecMap, } @@ -38,12 +38,11 @@ where { /// Allocate a fresh address. Checks that the address space is not full. pub fn alloc(next_addr: &mut C::F) -> Address { - let id = Address(*next_addr); *next_addr += C::F::one(); if next_addr.is_zero() { panic!("out of address space"); } - id + Address(*next_addr) } /// Map `fp` to its existing address without changing its mult. @@ -108,7 +107,7 @@ where /// /// Ensures that `addr` has already been assigned a `mult`. pub fn read_ghost_addr(&mut self, addr: Address) -> &mut C::F { - self.read_addr_internal(addr, true) + self.read_addr_internal(addr, false) } fn read_addr_internal(&mut self, addr: Address, increment_mult: bool) -> &mut C::F { @@ -143,27 +142,36 @@ where /// /// Increments the mult, first creating an entry if it does not yet exist. pub fn read_const(&mut self, imm: Imm) -> Address { - self.consts - .entry(imm) - .and_modify(|(_, x)| *x += C::F::one()) - .or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::one())) - .0 + use vec_map::Entry; + let addr = self.read_ghost_const(imm); + match self.addr_to_mult.entry(addr.as_usize()) { + Entry::Vacant(entry) => drop(entry.insert(C::F::one())), + Entry::Occupied(mut entry) => *entry.get_mut() += C::F::one(), + } + addr } /// Read a constant (a.k.a. immediate). /// /// Does not increment the mult. Creates an entry if it does not yet exist. pub fn read_ghost_const(&mut self, imm: Imm) -> Address { - self.consts.entry(imm).or_insert_with(|| (Self::alloc(&mut self.next_addr), C::F::zero())).0 + let addr = *self.consts.entry(imm).or_insert_with(|| Self::alloc(&mut self.next_addr)); + self.addr_to_mult.entry(addr.as_usize()).or_insert_with(C::F::zero); + addr } - fn mem_write_const(&mut self, dst: impl Reg, src: Imm) -> Instruction { - Instruction::Mem(MemInstr { - addrs: MemIo { inner: dst.write(self) }, - vals: MemIo { inner: src.as_block() }, - mult: C::F::zero(), - kind: MemAccessKind::Write, - }) + /// Turn `dst` into an alias for the constant `src`. 
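+    /// Panics if `dst` has already been assigned a physical address.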
+ fn mem_write_const(&mut self, dst: impl HasVirtualAddress, src: Imm) { + use vec_map::Entry; + let src_addr = src.read_ghost(self); + match self.virtual_to_physical.entry(dst.vaddr()) { + Entry::Vacant(entry) => drop(entry.insert(src_addr)), + Entry::Occupied(entry) => panic!( + "unexpected entry: virtual_to_physical[{:?}] = {:?}", + dst.vaddr(), + entry.get() + ), + } } fn base_alu( @@ -203,7 +211,7 @@ where use BaseAluOpcode::*; let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); f(self.base_alu(SubF, diff, lhs, rhs)); - f(self.base_alu(DivF, out, diff, Imm::F(C::F::zero()))); + f(self.base_alu(DivF, out, diff, Imm::f(C::F::zero()))); } fn base_assert_ne( @@ -216,7 +224,7 @@ where let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); f(self.base_alu(SubF, diff, lhs, rhs)); - f(self.base_alu(DivF, out, Imm::F(C::F::one()), diff)); + f(self.base_alu(DivF, out, Imm::f(C::F::one()), diff)); } fn ext_assert_eq( @@ -229,7 +237,7 @@ where let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); f(self.ext_alu(SubE, diff, lhs, rhs)); - f(self.ext_alu(DivE, out, diff, Imm::EF(C::EF::zero()))); + f(self.ext_alu(DivE, out, diff, Imm::ef(C::EF::zero()))); } fn ext_assert_ne( @@ -242,7 +250,7 @@ where let [diff, out] = core::array::from_fn(|_| Self::alloc(&mut self.next_addr)); f(self.ext_alu(SubE, diff, lhs, rhs)); - f(self.ext_alu(DivE, out, Imm::EF(C::EF::one()), diff)); + f(self.ext_alu(DivE, out, Imm::ef(C::EF::one()), diff)); } fn poseidon2_permute( @@ -384,71 +392,71 @@ where let mut f = |instr| consumer(Ok(instr)); match ir_instr { - DslIr::ImmV(dst, src) => f(self.mem_write_const(dst, Imm::F(src))), - DslIr::ImmF(dst, src) => f(self.mem_write_const(dst, Imm::F(src))), - DslIr::ImmE(dst, src) => f(self.mem_write_const(dst, Imm::EF(src))), + DslIr::ImmV(dst, src) => self.mem_write_const(dst, Imm::f(src)), + DslIr::ImmF(dst, src) => self.mem_write_const(dst, Imm::f(src)), + DslIr::ImmE(dst, src) => self.mem_write_const(dst, Imm::ef(src)), DslIr::AddV(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)), - DslIr::AddVI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))), + DslIr::AddVI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::f(rhs))), DslIr::AddF(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, rhs)), - DslIr::AddFI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::F(rhs))), + DslIr::AddFI(dst, lhs, rhs) => f(self.base_alu(AddF, dst, lhs, Imm::f(rhs))), DslIr::AddE(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)), - DslIr::AddEI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))), + DslIr::AddEI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::ef(rhs))), DslIr::AddEF(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, rhs)), - DslIr::AddEFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::F(rhs))), - DslIr::AddEFFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::EF(rhs))), + DslIr::AddEFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::f(rhs))), + DslIr::AddEFFI(dst, lhs, rhs) => f(self.ext_alu(AddE, dst, lhs, Imm::ef(rhs))), DslIr::SubV(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, rhs)), - DslIr::SubVI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))), - DslIr::SubVIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)), + DslIr::SubVI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::f(rhs))), + DslIr::SubVIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::f(lhs), rhs)), DslIr::SubF(dst, lhs, rhs) => 
f(self.base_alu(SubF, dst, lhs, rhs)), - DslIr::SubFI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::F(rhs))), - DslIr::SubFIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::F(lhs), rhs)), + DslIr::SubFI(dst, lhs, rhs) => f(self.base_alu(SubF, dst, lhs, Imm::f(rhs))), + DslIr::SubFIN(dst, lhs, rhs) => f(self.base_alu(SubF, dst, Imm::f(lhs), rhs)), DslIr::SubE(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)), - DslIr::SubEI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::EF(rhs))), - DslIr::SubEIN(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, Imm::EF(lhs), rhs)), - DslIr::SubEFI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::F(rhs))), + DslIr::SubEI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::ef(rhs))), + DslIr::SubEIN(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, Imm::ef(lhs), rhs)), + DslIr::SubEFI(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, Imm::f(rhs))), DslIr::SubEF(dst, lhs, rhs) => f(self.ext_alu(SubE, dst, lhs, rhs)), DslIr::MulV(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)), - DslIr::MulVI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))), + DslIr::MulVI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::f(rhs))), DslIr::MulF(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, rhs)), - DslIr::MulFI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::F(rhs))), + DslIr::MulFI(dst, lhs, rhs) => f(self.base_alu(MulF, dst, lhs, Imm::f(rhs))), DslIr::MulE(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)), - DslIr::MulEI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::EF(rhs))), - DslIr::MulEFI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::F(rhs))), + DslIr::MulEI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::ef(rhs))), + DslIr::MulEFI(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, Imm::f(rhs))), DslIr::MulEF(dst, lhs, rhs) => f(self.ext_alu(MulE, dst, lhs, rhs)), DslIr::DivF(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, rhs)), - DslIr::DivFI(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, Imm::F(rhs))), - DslIr::DivFIN(dst, lhs, rhs) => f(self.base_alu(DivF, dst, Imm::F(lhs), rhs)), + DslIr::DivFI(dst, lhs, rhs) => f(self.base_alu(DivF, dst, lhs, Imm::f(rhs))), + DslIr::DivFIN(dst, lhs, rhs) => f(self.base_alu(DivF, dst, Imm::f(lhs), rhs)), DslIr::DivE(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)), - DslIr::DivEI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::EF(rhs))), - DslIr::DivEIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::EF(lhs), rhs)), - DslIr::DivEFI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::F(rhs))), - DslIr::DivEFIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::F(lhs), rhs)), + DslIr::DivEI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::ef(rhs))), + DslIr::DivEIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::ef(lhs), rhs)), + DslIr::DivEFI(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, Imm::f(rhs))), + DslIr::DivEFIN(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, Imm::f(lhs), rhs)), DslIr::DivEF(dst, lhs, rhs) => f(self.ext_alu(DivE, dst, lhs, rhs)), - DslIr::NegV(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)), - DslIr::NegF(dst, src) => f(self.base_alu(SubF, dst, Imm::F(C::F::zero()), src)), - DslIr::NegE(dst, src) => f(self.ext_alu(SubE, dst, Imm::EF(C::EF::zero()), src)), - DslIr::InvV(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)), - DslIr::InvF(dst, src) => f(self.base_alu(DivF, dst, Imm::F(C::F::one()), src)), - DslIr::InvE(dst, src) => f(self.ext_alu(DivE, dst, Imm::F(C::F::one()), 
src)), + DslIr::NegV(dst, src) => f(self.base_alu(SubF, dst, Imm::f(C::F::zero()), src)), + DslIr::NegF(dst, src) => f(self.base_alu(SubF, dst, Imm::f(C::F::zero()), src)), + DslIr::NegE(dst, src) => f(self.ext_alu(SubE, dst, Imm::ef(C::EF::zero()), src)), + DslIr::InvV(dst, src) => f(self.base_alu(DivF, dst, Imm::f(C::F::one()), src)), + DslIr::InvF(dst, src) => f(self.base_alu(DivF, dst, Imm::f(C::F::one()), src)), + DslIr::InvE(dst, src) => f(self.ext_alu(DivE, dst, Imm::f(C::F::one()), src)), DslIr::AssertEqV(lhs, rhs) => self.base_assert_eq(lhs, rhs, f), DslIr::AssertEqF(lhs, rhs) => self.base_assert_eq(lhs, rhs, f), DslIr::AssertEqE(lhs, rhs) => self.ext_assert_eq(lhs, rhs, f), - DslIr::AssertEqVI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f), - DslIr::AssertEqFI(lhs, rhs) => self.base_assert_eq(lhs, Imm::F(rhs), f), - DslIr::AssertEqEI(lhs, rhs) => self.ext_assert_eq(lhs, Imm::EF(rhs), f), + DslIr::AssertEqVI(lhs, rhs) => self.base_assert_eq(lhs, Imm::f(rhs), f), + DslIr::AssertEqFI(lhs, rhs) => self.base_assert_eq(lhs, Imm::f(rhs), f), + DslIr::AssertEqEI(lhs, rhs) => self.ext_assert_eq(lhs, Imm::ef(rhs), f), DslIr::AssertNeV(lhs, rhs) => self.base_assert_ne(lhs, rhs, f), DslIr::AssertNeF(lhs, rhs) => self.base_assert_ne(lhs, rhs, f), DslIr::AssertNeE(lhs, rhs) => self.ext_assert_ne(lhs, rhs, f), - DslIr::AssertNeVI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f), - DslIr::AssertNeFI(lhs, rhs) => self.base_assert_ne(lhs, Imm::F(rhs), f), - DslIr::AssertNeEI(lhs, rhs) => self.ext_assert_ne(lhs, Imm::EF(rhs), f), + DslIr::AssertNeVI(lhs, rhs) => self.base_assert_ne(lhs, Imm::f(rhs), f), + DslIr::AssertNeFI(lhs, rhs) => self.base_assert_ne(lhs, Imm::f(rhs), f), + DslIr::AssertNeEI(lhs, rhs) => self.ext_assert_ne(lhs, Imm::ef(rhs), f), DslIr::CircuitV2Poseidon2PermuteBabyBear(data) => { f(self.poseidon2_permute(data.0, data.1)) @@ -536,7 +544,7 @@ where // Replace the mults using the address count data gathered in this previous. // Exhaustive match for refactoring purposes. - let total_memory = self.addr_to_mult.len() + self.consts.len(); + let total_memory = self.addr_to_mult.len(); let mut backfill = |(mult, addr): (&mut F, &Address)| { *mult = self.addr_to_mult.remove(addr.as_usize()).unwrap() }; @@ -553,12 +561,6 @@ where addrs: ExtAluIo { out: ref addr, .. }, .. }) => backfill((mult, addr)), - Instruction::Mem(MemInstr { - addrs: MemIo { inner: ref addr }, - mult, - kind: MemAccessKind::Write, - .. - }) => backfill((mult, addr)), Instruction::Poseidon2(instr) => { let Poseidon2SkinnyInstr { addrs: Poseidon2Io { output: ref addrs, .. }, @@ -595,21 +597,21 @@ where .iter_mut() .for_each(|(addr, mult)| backfill((mult, addr))); } + Instruction::Mem(_) => { + panic!("mem instructions should be produced through the `consts` map") + } // Instructions that do not write to memory. - Instruction::Mem(MemInstr { kind: MemAccessKind::Read, .. }) - | Instruction::CommitPublicValues(_) - | Instruction::Print(_) => (), + Instruction::CommitPublicValues(_) | Instruction::Print(_) => (), } } }); - debug_assert!(self.addr_to_mult.is_empty()); // Initialize constants. 
let total_consts = self.consts.len(); - let instrs_consts = take(&mut self.consts).into_iter().map(|(imm, (addr, mult))| { + let instrs_consts = take(&mut self.consts).into_iter().map(|(imm, addr)| { Instruction::Mem(MemInstr { addrs: MemIo { inner: addr }, vals: MemIo { inner: imm.as_block() }, - mult, + mult: self.addr_to_mult.remove(addr.as_usize()).unwrap(), kind: MemAccessKind::Write, }) }); @@ -627,6 +629,7 @@ where (instrs_consts.chain(instrs).collect(), traces) } }); + debug_assert!(self.addr_to_mult.is_empty()); RecursionProgram { instructions, total_memory, traces } } } @@ -667,6 +670,27 @@ pub enum Imm { EF(EF), } +impl Imm +where + F: Field, + EF: AbstractExtensionField, +{ + /// Wraps its argument in `Self::F`. + pub fn f(f: F) -> Self { + Self::F(f) + } + + /// If `ef` lives in the base field, then we encode it as `Self::F`. + /// Otherwise, we encode it as `Self::EF`. + pub fn ef(ef: EF) -> Self { + if ef.as_base_slice()[1..].iter().all(Field::is_zero) { + Self::F(ef.as_base_slice()[0]) + } else { + Self::EF(ef) + } + } +} + impl PartialOrd for Imm where F: PartialEq + AbstractField + PartialOrd, @@ -711,6 +735,25 @@ where } } +/// Expose the "virtual address" counter of the variable types. +trait HasVirtualAddress { + fn vaddr(&self) -> usize; +} + +macro_rules! impl_has_virtual_address { + ($type:ident<$($gen:ident),*>) => { + impl<$($gen),*> HasVirtualAddress for $type<$($gen),*> { + fn vaddr(&self) -> usize { + self.0 as usize + } + } + }; +} + +impl_has_virtual_address!(Var); +impl_has_virtual_address!(Felt); +impl_has_virtual_address!(Ext); + /// Utility functions for various register types. trait Reg { /// Mark the register as to be read from, returning the "physical" address. @@ -723,54 +766,20 @@ trait Reg { fn write(&self, compiler: &mut AsmCompiler) -> Address; } -macro_rules! impl_reg_borrowed { - ($a:ty) => { - impl Reg for $a - where - C: Config, - T: Reg + ?Sized, - { - fn read(&self, compiler: &mut AsmCompiler) -> Address { - (**self).read(compiler) - } - - fn read_ghost(&self, compiler: &mut AsmCompiler) -> Address { - (**self).read_ghost(compiler) - } - - fn write(&self, compiler: &mut AsmCompiler) -> Address { - (**self).write(compiler) - } - } - }; -} +impl, T: HasVirtualAddress> Reg for T { + fn read(&self, compiler: &mut AsmCompiler) -> Address { + compiler.read_vaddr(self.vaddr()) + } -// Allow for more flexibility in arguments. -impl_reg_borrowed!(&T); -impl_reg_borrowed!(&mut T); -impl_reg_borrowed!(Box); + fn read_ghost(&self, compiler: &mut AsmCompiler) -> Address { + compiler.read_ghost_vaddr(self.vaddr()) + } -macro_rules! impl_reg_vaddr { - ($a:ty) => { - impl> Reg for $a { - fn read(&self, compiler: &mut AsmCompiler) -> Address { - compiler.read_vaddr(self.0 as usize) - } - fn read_ghost(&self, compiler: &mut AsmCompiler) -> Address { - compiler.read_ghost_vaddr(self.0 as usize) - } - fn write(&self, compiler: &mut AsmCompiler) -> Address { - compiler.write_fp(self.0 as usize) - } - } - }; + fn write(&self, compiler: &mut AsmCompiler) -> Address { + compiler.write_fp(self.vaddr()) + } } -// These three types have `.fp()` but they don't share a trait. 
-impl_reg_vaddr!(Var); -impl_reg_vaddr!(Felt); -impl_reg_vaddr!(Ext); - impl> Reg for Imm { fn read(&self, compiler: &mut AsmCompiler) -> Address { compiler.read_const(*self) diff --git a/crates/recursion/core-v2/src/chips/mem/constant.rs b/crates/recursion/core-v2/src/chips/mem/constant.rs index 1b3654a1e0..681cbac8dd 100644 --- a/crates/recursion/core-v2/src/chips/mem/constant.rs +++ b/crates/recursion/core-v2/src/chips/mem/constant.rs @@ -12,7 +12,7 @@ use crate::{builder::SP1RecursionAirBuilder, *}; use super::MemoryAccessCols; -pub const NUM_MEM_ENTRIES_PER_ROW: usize = 16; +pub const NUM_MEM_ENTRIES_PER_ROW: usize = 6; #[derive(Default)] pub struct MemoryChip {
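
Note: below is a minimal, self-contained sketch of the constant-propagation scheme this patch implements. The names `ConstPool`, `Address(u64)`, and the `u64` / `[u64; 4]` stand-ins for `C::F` / `C::EF` are hypothetical simplifications, not the compiler's actual types or signatures; only the interning idea mirrors the patch. Each distinct immediate is interned once in a `BTreeMap` keyed by value, `read_const` bumps the read count (mult) while `read_ghost_const` only ensures the entry exists, `Imm::ef` folds an extension constant that lies in the base field down to a base-field constant so it dedups with `Imm::f`, and finalization emits exactly one memory write per distinct constant carrying its total mult.

use std::collections::BTreeMap;

/// Stand-in for a base- or extension-field immediate (hypothetical: `u64`
/// replaces `C::F`, `[u64; 4]` replaces `C::EF`).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
enum Imm {
    F(u64),
    EF([u64; 4]),
}

impl Imm {
    /// Fold an extension element that lies in the base field down to `Imm::F`,
    /// so that e.g. `Imm::ef([3, 0, 0, 0])` and `Imm::F(3)` intern to the same address.
    fn ef(coords: [u64; 4]) -> Self {
        if coords[1..].iter().all(|&c| c == 0) {
            Imm::F(coords[0])
        } else {
            Imm::EF(coords)
        }
    }
}

/// Stand-in for a "physical" memory address.
#[derive(Clone, Copy, Debug)]
struct Address(u64);

/// Hypothetical pool mirroring the `consts` / `addr_to_mult` bookkeeping above.
#[derive(Default)]
struct ConstPool {
    next_addr: u64,
    /// Constant value -> its unique physical address.
    consts: BTreeMap<Imm, Address>,
    /// Physical address -> read count (mult).
    addr_to_mult: BTreeMap<u64, u64>,
}

impl ConstPool {
    /// Intern `imm` without counting a read (a "ghost" read).
    fn read_ghost_const(&mut self, imm: Imm) -> Address {
        let next_addr = &mut self.next_addr;
        let addr = *self.consts.entry(imm).or_insert_with(|| {
            let addr = Address(*next_addr);
            *next_addr += 1;
            addr
        });
        self.addr_to_mult.entry(addr.0).or_insert(0);
        addr
    }

    /// Intern `imm` and count one read of it.
    fn read_const(&mut self, imm: Imm) -> Address {
        let addr = self.read_ghost_const(imm);
        *self.addr_to_mult.get_mut(&addr.0).unwrap() += 1;
        addr
    }

    /// Emit one (address, value, mult) write per distinct constant, analogous
    /// to generating the constant `MemInstr` writes from the `consts` map.
    fn finalize(self) -> Vec<(Address, Imm, u64)> {
        let ConstPool { consts, addr_to_mult, .. } = self;
        consts
            .into_iter()
            .map(|(imm, addr)| (addr, imm, addr_to_mult[&addr.0]))
            .collect()
    }
}

fn main() {
    let mut pool = ConstPool::default();
    let a = pool.read_const(Imm::F(3));
    let b = pool.read_const(Imm::ef([3, 0, 0, 0])); // same constant as `a` after folding
    let _c = pool.read_const(Imm::ef([1, 2, 0, 0]));
    assert_eq!(a.0, b.0); // both reads of "3" share one address, so only one write is emitted
    for (addr, imm, mult) in pool.finalize() {
        println!("write {imm:?} at {addr:?} with mult {mult}");
    }
}

The payoff is that repeated `ImmV` / `ImmF` / `ImmE` instructions for the same value no longer each emit a `MemInstr` write: the writes are generated once from the `consts` map at the end of compilation with their accumulated mults, which is also why the mult-backfill pass above now treats any remaining `Instruction::Mem` as a bug.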