Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: decode KZG points directly into the buffers #840

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 10 additions & 37 deletions bins/revme/src/cmd/format_kzg_setup.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub use revm::primitives::kzg::{format_kzg_settings, G1Points, G2Points, KzgErrors};
pub use revm::primitives::kzg::{parse_kzg_trusted_setup, G1Points, G2Points, KzgErrors};
use std::path::PathBuf;
use std::{env, fs};
use structopt::StructOpt;
Expand Down Expand Up @@ -31,7 +31,7 @@ impl Cmd {
fs::read_to_string(&self.path).map_err(|_| KzgErrors::NotValidFile)?;

// format points
let (g1, g2) = format_kzg_settings(&kzg_trusted_settings)?;
let (g1, g2) = parse_kzg_trusted_setup(&kzg_trusted_settings)?;

let g1_path = self
.g1
Expand All @@ -44,44 +44,17 @@ impl Cmd {
.unwrap_or_else(|| out_dir.join("g2_points.bin"));

// output points
fs::write(&g1_path, into_flattened(g1.to_vec())).map_err(|_| KzgErrors::IOError)?;
fs::write(&g2_path, into_flattened(g2.to_vec())).map_err(|_| KzgErrors::IOError)?;
fs::write(&g1_path, flatten(&g1.0)).map_err(|_| KzgErrors::IOError)?;
fs::write(&g2_path, flatten(&g2.0)).map_err(|_| KzgErrors::IOError)?;
println!("Finished formatting kzg trusted setup into binary representation.");
println!("G1 point path: {:?}", g1_path);
println!("G2 point path: {:?}", g2_path);
println!("G1 points path: {:?}", g1_path);
println!("G2 points path: {:?}", g2_path);
Ok(())
}
}

/// [`Vec::into_flattened`].
fn into_flattened<T, const N: usize>(vec: Vec<[T; N]>) -> Vec<T> {
let (ptr, len, cap) = into_raw_parts(vec);
let (new_len, new_cap) = if core::mem::size_of::<T>() == 0 {
(len.checked_mul(N).expect("vec len overflow"), usize::MAX)
} else {
// SAFETY:
// - `cap * N` cannot overflow because the allocation is already in
// the address space.
// - Each `[T; N]` has `N` valid elements, so there are `len * N`
// valid elements in the allocation.
unsafe {
(
len.checked_mul(N).unwrap_unchecked(),
cap.checked_mul(N).unwrap_unchecked(),
)
}
};
// SAFETY:
// - `ptr` was allocated by `self`
// - `ptr` is well-aligned because `[T; N]` has the same alignment as `T`.
// - `new_cap` refers to the same sized allocation as `cap` because
// `new_cap * size_of::<T>()` == `cap * size_of::<[T; N]>()`
// - `len` <= `cap`, so `len * N` <= `cap * N`.
unsafe { Vec::from_raw_parts(ptr.cast(), new_len, new_cap) }
}

/// [`Vec::into_raw_parts`]
fn into_raw_parts<T>(vec: Vec<T>) -> (*mut T, usize, usize) {
let mut me = core::mem::ManuallyDrop::new(vec);
(me.as_mut_ptr(), me.len(), me.capacity())
fn flatten<const N: usize, const M: usize>(x: &[[u8; N]; M]) -> &[u8] {
// SAFETY: `x` is a valid `[[u8; N]; M]` and `N * M` is the length of the
// returned slice.
unsafe { core::slice::from_raw_parts(x.as_ptr().cast(), N * M) }
}
2 changes: 1 addition & 1 deletion crates/primitives/src/kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mod trusted_setup_points;
pub use c_kzg::KzgSettings;
pub use env_settings::EnvKzgSettings;
pub use trusted_setup_points::{
format_kzg_settings, G1Points, G2Points, KzgErrors, BYTES_PER_G1_POINT, BYTES_PER_G2_POINT,
parse_kzg_trusted_setup, G1Points, G2Points, KzgErrors, BYTES_PER_G1_POINT, BYTES_PER_G2_POINT,
G1_POINTS, G2_POINTS, NUM_G1_POINTS, NUM_G2_POINTS,
};
43 changes: 22 additions & 21 deletions crates/primitives/src/kzg/trusted_setup_points.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ pub const NUM_G2_POINTS: usize = 65;

/// A newtype over list of G1 point from kzg trusted setup.
#[derive(Debug, Clone, PartialEq, AsRef, AsMut, Deref, DerefMut)]
#[repr(transparent)]
pub struct G1Points(pub [[u8; BYTES_PER_G1_POINT]; NUM_G1_POINTS]);

/// A newtype over list of G2 point from kzg trusted setup.
#[derive(Debug, Clone, Eq, PartialEq, AsRef, AsMut, Deref, DerefMut)]
pub struct G2Points(pub [[u8; BYTES_PER_G2_POINT]; NUM_G2_POINTS]);

impl Default for G1Points {
fn default() -> Self {
Self([[0; BYTES_PER_G1_POINT]; NUM_G1_POINTS])
}
}

/// A newtype over list of G2 point from kzg trusted setup.
#[derive(Debug, Clone, Eq, PartialEq, AsRef, AsMut, Deref, DerefMut)]
#[repr(transparent)]
pub struct G2Points(pub [[u8; BYTES_PER_G2_POINT]; NUM_G2_POINTS]);

impl Default for G2Points {
fn default() -> Self {
Self([[0; BYTES_PER_G2_POINT]; NUM_G2_POINTS])
Expand All @@ -42,16 +45,14 @@ pub const G2_POINTS: &G2Points = {
unsafe { &*BYTES.as_ptr().cast::<G2Points>() }
};

/// Pros over `include_str!(<path-to-trusted-setup>)`:
/// - partially decoded (hex strings -> point bytes)
/// - smaller runtime static size (198K = `4096*48 + 65*96` vs 404K)
/// - don't have to do weird hacks to call `load_trusted_setup_file` at runtime, see
/// [Reth](https://github.com/paradigmxyz/reth/blob/b839e394a45edbe7b2030fb370420ca771e5b728/crates/primitives/src/constants/eip4844.rs#L44-L52)
pub fn format_kzg_settings(
/// Parses the contents of a KZG trusted setup file into a list of G1 and G2 points.
///
/// These can then be used to create a KZG settings object with
/// [`KzgSettings::load_trusted_setup`](c_kzg::KzgSettings::load_trusted_setup).
pub fn parse_kzg_trusted_setup(
trusted_setup: &str,
) -> Result<(Box<G1Points>, Box<G2Points>), KzgErrors> {
let contents = trusted_setup;
let mut lines = contents.lines();
let mut lines = trusted_setup.lines();

// load number of points
let n_g1 = lines
Expand All @@ -65,26 +66,26 @@ pub fn format_kzg_settings(
.parse::<usize>()
.map_err(|_| KzgErrors::ParseError)?;

if n_g2 != 65 {
if n_g1 != NUM_G1_POINTS {
return Err(KzgErrors::MismatchedNumberOfPoints);
}

if n_g2 != NUM_G2_POINTS {
return Err(KzgErrors::MismatchedNumberOfPoints);
}

// load g1 points
let mut g1_points = Box::<G1Points>::default();
for i in 0..n_g1 {
for bytes in &mut g1_points.0 {
let line = lines.next().ok_or(KzgErrors::FileFormatError)?;
let mut bytes = [0; BYTES_PER_G1_POINT];
crate::hex::decode_to_slice(line, &mut bytes).map_err(|_| KzgErrors::ParseError)?;
g1_points[i] = bytes;
crate::hex::decode_to_slice(line, bytes).map_err(|_| KzgErrors::ParseError)?;
}

// load g2 points
let mut g2_points = Box::<G2Points>::default();
for i in 0..n_g2 {
for bytes in &mut g2_points.0 {
let line = lines.next().ok_or(KzgErrors::FileFormatError)?;
let mut bytes = [0; BYTES_PER_G2_POINT];
crate::hex::decode_to_slice(line, &mut bytes).map_err(|_| KzgErrors::ParseError)?;
g2_points[i] = bytes;
crate::hex::decode_to_slice(line, bytes).map_err(|_| KzgErrors::ParseError)?;
}

if lines.next().is_some() {
Expand Down
Loading