Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: De-couple Chips from a specific ExecutionRecord, part II #37

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ use crate::{

/// A description of the events related to this AIR.
pub trait WithEvents<'a>: Sized {
/// output of a functional lens from the Record to
/// refs of those events relative to the AIR.
type Events: 'a;
/// the input events that this AIR needs to get a reference to in order to lay out its trace
type InputEvents: 'a;

// the output events that this AIR produces
type OutputEvents: 'a;
}

/// A trait intended for implementation on Records that may store events related to Chips,
Expand All @@ -23,7 +25,11 @@ pub trait WithEvents<'a>: Sized {
///
/// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 )
pub trait EventLens<T: for<'b> WithEvents<'b>>: Indexed {
fn events(&self) -> <T as WithEvents<'_>>::Events;
fn events(&self) -> <T as WithEvents<'_>>::InputEvents;
}

pub trait EventMutLens<T: for<'b> WithEvents<'b>> {
fn add_events(&mut self, events: <T as WithEvents<'_>>::OutputEvents);
}

//////////////// Derive macro shenanigans ////////////////////////////////////////////////
Expand Down Expand Up @@ -63,10 +69,10 @@ where
R: EventLens<T>,
U: for<'b> WithEvents<'b>,
// see https://github.com/rust-lang/rust/issues/86702 for the empty parameter
F: for<'c> Fn(<T as WithEvents<'c>>::Events, &'c ()) -> <U as WithEvents<'c>>::Events,
F: for<'c> Fn(<T as WithEvents<'c>>::InputEvents, &'c ()) -> <U as WithEvents<'c>>::InputEvents,
{
fn events<'c>(&'c self) -> <U as WithEvents<'c>>::Events {
let events: <T as WithEvents<'c>>::Events = self.record.events();
fn events<'c>(&'c self) -> <U as WithEvents<'c>>::InputEvents {
let events: <T as WithEvents<'c>>::InputEvents = self.record.events();
(self.projection)(events, &())
}
}
Expand All @@ -80,12 +86,56 @@ where
self.record.index()
}
}

/// if I have an EventMutLens from T::Events, and a way (F) to deduce T::Events from U::Events,
/// I can compose them to get an EventMutLens from U::Events.
pub struct Inj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventMutLens<T>,
{
record: &'a mut R,
injection: F,
_phantom: PhantomData<T>,
}

/// A constructor for the projection from T::Events to U::Events.
impl<'a, T, R, F> Inj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventMutLens<T>,
{
pub fn new(record: &'a mut R, injection: F) -> Self {
Self {
record,
injection,
_phantom: PhantomData,
}
}
}

impl<'a, T, R, U, F> EventMutLens<U> for Inj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventMutLens<T>,
U: for<'b> WithEvents<'b>,
// see https://github.com/rust-lang/rust/issues/86702 for the empty parameter
F: for<'c> Fn(
<U as WithEvents<'c>>::OutputEvents,
&'c (),
) -> <T as WithEvents<'c>>::OutputEvents,
{
fn add_events(&mut self, events: <U as WithEvents<'_>>::OutputEvents) {
let events: <T as WithEvents<'_>>::OutputEvents = (self.injection)(events, &());
self.record.add_events(events);
}
}
//////////////// end of shenanigans destined for the derive macros. ////////////////

/// An AIR that is part of a multi table AIR arithmetization.
pub trait MachineAir<F: Field>: BaseAir<F> + for<'a> WithEvents<'a> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord + EventLens<Self>;
type Record: MachineRecord + EventLens<Self> + EventMutLens<Self>;

type Program: MachineProgram<F>;

Expand All @@ -97,14 +147,18 @@ pub trait MachineAir<F: Field>: BaseAir<F> + for<'a> WithEvents<'a> {
/// - `input` is the execution record containing the events to be written to the trace.
/// - `output` is the execution record containing events that the `MachineAir` can add to
/// the record such as byte lookup requests.
fn generate_trace<EL: EventLens<Self>>(
fn generate_trace<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut Self::Record,
output: &mut OL,
) -> RowMajorMatrix<F>;

/// Generate the dependencies for a given execution record.
fn generate_dependencies<EL: EventLens<Self>>(&self, input: &EL, output: &mut Self::Record) {
fn generate_dependencies<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut OL,
) {
self.generate_trace(input, output);
}

Expand Down
17 changes: 9 additions & 8 deletions core/src/alu/add_sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use p3_maybe_rayon::prelude::ParallelSlice;
use sphinx_derive::AlignedBorrow;

use crate::{
air::{AluAirBuilder, EventLens, MachineAir, WithEvents, Word},
air::{AluAirBuilder, EventLens, EventMutLens, MachineAir, WithEvents, Word},
bytes::ByteLookupEvent,
operations::AddOperation,
runtime::{ExecutionRecord, Opcode, Program},
stark::MachineRecord,
utils::pad_to_power_of_two,
};

Expand Down Expand Up @@ -58,12 +58,13 @@ pub struct AddSubCols<T> {
}

impl<'a> WithEvents<'a> for AddSubChip {
type Events = (
type InputEvents = (
// add events
&'a [AluEvent],
// sub events
&'a [AluEvent],
);
type OutputEvents = &'a [ByteLookupEvent];
}

impl<F: PrimeField> MachineAir<F> for AddSubChip {
Expand All @@ -75,10 +76,10 @@ impl<F: PrimeField> MachineAir<F> for AddSubChip {
"AddSub".to_string()
}

fn generate_trace<EL: EventLens<Self>>(
fn generate_trace<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut Self::Record,
output: &mut OL,
) -> RowMajorMatrix<F> {
let (add_events, sub_events) = input.events();
// Generate the rows for the trace.
Expand All @@ -91,7 +92,7 @@ impl<F: PrimeField> MachineAir<F> for AddSubChip {
let rows_and_records = merged_events
.par_chunks(chunk_size)
.map(|events| {
let mut record = ExecutionRecord::default();
let mut record = Vec::new();
let rows = events
.iter()
.map(|event| {
Expand All @@ -117,9 +118,9 @@ impl<F: PrimeField> MachineAir<F> for AddSubChip {
.collect::<Vec<_>>();

let mut rows: Vec<[F; NUM_ADD_SUB_COLS]> = vec![];
for mut row_and_record in rows_and_records {
for row_and_record in rows_and_records {
rows.extend(row_and_record.0);
output.append(&mut row_and_record.1);
output.add_events(&row_and_record.1);
}

// Convert the trace to a row major matrix.
Expand Down
12 changes: 6 additions & 6 deletions core/src/alu/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sphinx_derive::AlignedBorrow;

use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir};
use crate::air::{AluAirBuilder, ByteAirBuilder, EventMutLens, MachineAir};
use crate::air::{EventLens, WithEvents, Word};
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::utils::pad_to_power_of_two;
Expand Down Expand Up @@ -52,7 +51,8 @@ pub struct BitwiseCols<T> {
}

impl<'a> WithEvents<'a> for BitwiseChip {
type Events = &'a [AluEvent];
type InputEvents = &'a [AluEvent];
type OutputEvents = &'a [ByteLookupEvent];
}

impl<F: PrimeField> MachineAir<F> for BitwiseChip {
Expand All @@ -64,10 +64,10 @@ impl<F: PrimeField> MachineAir<F> for BitwiseChip {
"Bitwise".to_string()
}

fn generate_trace<EL: EventLens<Self>>(
fn generate_trace<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut ExecutionRecord,
output: &mut OL,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let rows = input
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<F: PrimeField> MachineAir<F> for BitwiseChip {
b: u32::from(b_b),
c: u32::from(b_c),
};
output.add_byte_lookup_event(byte_event);
output.add_events(&[byte_event]);
}

row
Expand Down
40 changes: 29 additions & 11 deletions core/src/alu/divrem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ use p3_matrix::Matrix;
use sphinx_derive::AlignedBorrow;

use self::utils::eval_abs_value;
use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder};
use crate::air::{AluAirBuilder, ByteAirBuilder, EventMutLens, MachineAir, WordAirBuilder};
use crate::air::{EventLens, WithEvents, Word};
use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation};
use crate::alu::AluEvent;
Expand Down Expand Up @@ -187,8 +187,15 @@ pub struct DivRemCols<T> {
pub is_real: T,
}

pub enum DivRemEvent<'a> {
ByteLookupEvent(&'a ByteLookupEvent),
MulEvent(&'a AluEvent),
LtEvent(&'a AluEvent),
}

impl<'a> WithEvents<'a> for DivRemChip {
type Events = &'a [AluEvent];
type InputEvents = &'a [AluEvent];
type OutputEvents = &'a [DivRemEvent<'a>];
}

impl<F: PrimeField> MachineAir<F> for DivRemChip {
Expand All @@ -200,11 +207,15 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
"DivRem".to_string()
}

fn generate_trace<EL: EventLens<Self>>(
fn generate_trace<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut ExecutionRecord,
output: &mut OL,
) -> RowMajorMatrix<F> {
let mut byte_events = Vec::new();
let mut mul_events = Vec::new();
let mut lt_events = Vec::new();

// Generate the trace rows for each event.
let divrem_events = input.events();
let mut rows: Vec<[F; NUM_DIVREM_COLS]> = Vec::with_capacity(divrem_events.len());
Expand Down Expand Up @@ -273,7 +284,7 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
c: 0,
});
}
output.add_byte_lookup_events(blu_events);
byte_events.extend(blu_events);
}
}

Expand Down Expand Up @@ -332,7 +343,7 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
c: event.c,
b: quotient,
};
output.add_mul_event(lower_multiplication);
mul_events.push(lower_multiplication);

let upper_multiplication = AluEvent {
shard: event.shard,
Expand All @@ -349,7 +360,7 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
b: quotient,
};

output.add_mul_event(upper_multiplication);
mul_events.push(upper_multiplication);

let lt_event = if is_signed_operation(event.opcode) {
AluEvent {
Expand All @@ -370,16 +381,23 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
clk: event.clk,
}
};
output.add_lt_event(lt_event);
lt_events.push(lt_event);
}

// Range check.
{
output.add_u8_range_checks(event.shard, &quotient.to_le_bytes());
output.add_u8_range_checks(event.shard, &remainder.to_le_bytes());
output.add_u8_range_checks(event.shard, &c_times_quotient);
byte_events.add_u8_range_checks(event.shard, &quotient.to_le_bytes());
byte_events.add_u8_range_checks(event.shard, &remainder.to_le_bytes());
byte_events.add_u8_range_checks(event.shard, &c_times_quotient);
}
}
let events = byte_events
.iter()
.map(DivRemEvent::ByteLookupEvent)
.chain(mul_events.iter().map(DivRemEvent::MulEvent))
.chain(lt_events.iter().map(DivRemEvent::LtEvent))
.collect::<Vec<_>>();
output.add_events(&events);

rows.push(row);
}
Expand Down
11 changes: 6 additions & 5 deletions core/src/alu/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use sphinx_derive::AlignedBorrow;

use crate::air::{AluAirBuilder, BaseAirBuilder, ByteAirBuilder, MachineAir};
use crate::air::{AluAirBuilder, BaseAirBuilder, ByteAirBuilder, EventMutLens, MachineAir};
use crate::air::{EventLens, WithEvents, Word};
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
Expand Down Expand Up @@ -94,7 +94,8 @@ impl LtCols<u32> {
}

impl<'a> WithEvents<'a> for LtChip {
type Events = &'a [AluEvent];
type InputEvents = &'a [AluEvent];
type OutputEvents = &'a [ByteLookupEvent];
}

impl<F: PrimeField32> MachineAir<F> for LtChip {
Expand All @@ -106,10 +107,10 @@ impl<F: PrimeField32> MachineAir<F> for LtChip {
"Lt".to_string()
}

fn generate_trace<EL: EventLens<Self>>(
fn generate_trace<EL: EventLens<Self>, OL: EventMutLens<Self>>(
&self,
input: &EL,
output: &mut ExecutionRecord,
output: &mut OL,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let (rows, new_byte_lookup_events): (Vec<_>, Vec<_>) = input
Expand Down Expand Up @@ -211,7 +212,7 @@ impl<F: PrimeField32> MachineAir<F> for LtChip {
.unzip();

for byte_lookup_events in new_byte_lookup_events {
output.add_byte_lookup_events(byte_lookup_events);
output.add_events(&byte_lookup_events);
}

// Convert the trace to a row major matrix.
Expand Down
Loading
Loading