Skip to content

Commit

Permalink
feat: add CrossJoinExec
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 3, 2024
1 parent 16ed809 commit a9fce18
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
129 changes: 129 additions & 0 deletions crates/proof-of-sql/src/sql/proof_plans/cross_join_exec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use super::DynProofPlan;
use crate::{
base::{
database::{
join_util::cross_join, ColumnField, ColumnRef, OwnedTable, Table,
TableEvaluation, TableOptions, TableRef,
},
map::{IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
},
sql::proof::{
CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate,
VerificationBuilder,
},
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::iter::repeat_with;
use serde::{Deserialize, Serialize};

/// `ProofPlan` for queries of the form
/// ```ignore
/// <ProofPlan> JOIN <ProofPlan>
/// ```
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CrossJoinExec {
pub(super) left: Box<DynProofPlan>,
pub(super) right: Box<DynProofPlan>,
}

impl CrossJoinExec {
/// Create a new `CrossJoinExec` with the given left and right plans
pub fn new(left: Box<DynProofPlan>, right: Box<DynProofPlan>) -> Self {
Self { left, right }
}
}

impl ProofPlan for CrossJoinExec
where
CrossJoinExec: ProverEvaluate,
{
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.left.count(builder)?;
self.right.count(builder)?;
Ok(())
}

#[allow(unused_variables)]
fn verifier_evaluate<S: Scalar>(
&self,
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
// 1. columns
// TODO: Make sure `GroupByExec` as self.input is supported
let left_eval = self
.left
.verifier_evaluate(builder, accessor, None, one_eval_map)?;
let right_eval = self
.right
.verifier_evaluate(builder, accessor, None, one_eval_map)?;
let output_one_eval = builder.consume_one_eval();
todo!()
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.left
.get_column_result_fields()
.into_iter()
.chain(self.right.get_column_result_fields())
.collect()
}

fn get_column_references(&self) -> IndexSet<ColumnRef> {
self.left
.get_column_references()
.into_iter()
.chain(self.right.get_column_references())
.collect()
}

fn get_table_references(&self) -> IndexSet<TableRef> {
self.left
.get_table_references()
.into_iter()
.chain(self.right.get_table_references())
.collect()
}
}

impl ProverEvaluate for CrossJoinExec {
#[tracing::instrument(name = "CrossJoinExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
&self,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
let (left, left_one_eval_lengths) = self.left.result_evaluate(alloc, table_map);
let (right, right_one_eval_lengths) = self.right.result_evaluate(alloc, table_map);
let res = cross_join(left, right, alloc);
let one_eval_lengths = input_one_eval_lengths
.into_iter()
.chain(right_one_eval_lengths.into_iter())
.chain(core::iter::once(left.num_rows() * right.num_rows()))
.collect();
(res, one_eval_lengths)
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
self.left.first_round_evaluate(builder);
self.right.first_round_evaluate(builder);
}

#[tracing::instrument(name = "CrossJoinExec::prover_evaluate", level = "debug", skip_all)]
#[allow(unused_variables)]
fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {
let left = self.left.prover_evaluate(builder, alloc, table_map);
let right = self.right.prover_evaluate(builder, alloc, table_map);
todo!()
}
}
7 changes: 6 additions & 1 deletion crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec};
use super::{CrossJoinExec, EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec};
use crate::{
base::{
database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef},
Expand Down Expand Up @@ -43,4 +43,9 @@ pub enum DynProofPlan {
/// SELECT <result_expr1>, ..., <result_exprN> FROM <table> WHERE <where_clause>
/// ```
Filter(FilterExec),
/// [`ProofPlan`] for queries of the form
/// ```ignore
/// <ProofPlan> JOIN <ProofPlan>
/// ```
CrossJoin(CrossJoinExec),
}
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/sql/proof_plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,8 @@ pub(crate) use group_by_exec::GroupByExec;
#[cfg(all(test, feature = "blitzar"))]
mod group_by_exec_test;

mod cross_join_exec;
pub(crate) use cross_join_exec::CrossJoinExec;

mod dyn_proof_plan;
pub use dyn_proof_plan::DynProofPlan;

0 comments on commit a9fce18

Please sign in to comment.