Skip to content

Commit

Permalink
feat: generalize sum_expr in GroupByExpr
Browse files Browse the repository at this point in the history
- fix bug in `AggregateExpr`
  • Loading branch information
iajoiner committed Jun 30, 2024
1 parent 24b3592 commit e48f938
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 117 deletions.
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/ast/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ impl<C: Commitment> AggregateExpr<C> {
}

impl<C: Commitment> ProvableExpr<C> for AggregateExpr<C> {
fn count(&self, _builder: &mut CountBuilder) -> Result<(), ProofError> {
Ok(())
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.expr.count(builder)
}

fn data_type(&self) -> ColumnType {
Expand Down
46 changes: 24 additions & 22 deletions crates/proof-of-sql/src/sql/ast/group_by_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::{
aggregate_columns, fold_columns, fold_vals,
group_by_util::{compare_indexes_by_owned_columns, AggregatedColumns},
provable_expr_plan::ProvableExprPlan,
ColumnExpr, ProvableExpr, TableExpr,
AliasedProvableExprPlan, ColumnExpr, ProvableExpr, ProvableExprPlan, TableExpr,
};
use crate::{
base::{
Expand Down Expand Up @@ -30,7 +29,7 @@ use std::collections::HashSet;
/// Provable expressions for queries of the form
/// ```ignore
/// SELECT <group_by_expr1>, ..., <group_by_exprM>,
/// SUM(<sum_expr1>.0) as <sum_expr1>.1, ..., SUM(<sum_exprN>.0) as <sum_exprN>.1,
/// SUM(<sum_expr1>.expr) as <sum_expr1>.alias, ..., SUM(<sum_exprN>.expr) as <sum_exprN>.alias,
/// COUNT(*) as count_alias
/// FROM <table>
/// WHERE <where_clause>
Expand All @@ -41,7 +40,7 @@ use std::collections::HashSet;
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GroupByExpr<C: Commitment> {
pub(super) group_by_exprs: Vec<ColumnExpr<C>>,
pub(super) sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
pub(super) sum_expr: Vec<AliasedProvableExprPlan<C>>,
pub(super) count_alias: Identifier,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
Expand All @@ -51,7 +50,7 @@ impl<C: Commitment> GroupByExpr<C> {
/// Creates a new group_by expression.
pub fn new(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: Identifier,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
Expand All @@ -77,8 +76,8 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
expr.count(builder)?;
builder.count_result_columns(1);
}
for expr in self.sum_expr.iter() {
expr.0.count(builder)?;
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_result_columns(1);
Expand Down Expand Up @@ -115,7 +114,7 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
let aggregate_evals = self
.sum_expr
.iter()
.map(|expr| expr.0.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
// 3. indexes
let indexes_eval = builder
Expand Down Expand Up @@ -169,15 +168,17 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
let mut fields = Vec::new();
for col in self.group_by_exprs.iter() {
fields.push(col.get_column_field());
}
for col in self.sum_expr.iter() {
fields.push(col.1);
}
fields.push(ColumnField::new(self.count_alias, ColumnType::BigInt));
fields
self.group_by_exprs
.iter()
.map(|col| col.get_column_field())
.chain(self.sum_expr.iter().map(|aliased_expr| {
ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())
}))
.chain(std::iter::once(ColumnField::new(
self.count_alias,
ColumnType::BigInt,
)))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
Expand All @@ -186,8 +187,8 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
for col in self.group_by_exprs.iter() {
columns.insert(col.get_column_reference());
}
for col in self.sum_expr.iter() {
columns.insert(col.0.get_column_reference());
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
Expand Down Expand Up @@ -219,8 +220,9 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|expr| {
expr.0
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
// Compute filtered_columns and indexes
Expand Down Expand Up @@ -267,7 +269,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
let sum_columns = Vec::from_iter(
self.sum_expr
.iter()
.map(|expr| expr.0.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let AggregatedColumns {
Expand Down
63 changes: 38 additions & 25 deletions crates/proof-of-sql/src/sql/ast/group_by_expr_test.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use super::test_utility::{
and, cols_expr, column, const_int128, const_varchar, equal, group_by, sums_expr, tab,
};
use super::test_utility::*;
use crate::{
base::{
commitment::InnerProductProof,
database::{owned_table_utility::*, ColumnType, OwnedTableTestAccessor, TestAccessor},
database::{owned_table_utility::*, OwnedTableTestAccessor, TestAccessor},
scalar::Curve25519Scalar,
},
sql::proof::{exercise_verification, VerifiableQueryResult},
};

/// select a, sum(c * 2 + 1) as sum_c, count(*) as __count__ from sxt.t where b = 99 group by a
#[test]
fn we_can_prove_a_simple_group_by_with_bigint_columns() {
let data = owned_table([
Expand All @@ -22,7 +21,13 @@ fn we_can_prove_a_simple_group_by_with_bigint_columns() {
accessor.add_table(t, data, 0);
let expr = group_by(
cols_expr(t, &["a"], &accessor),
sums_expr(t, &["c"], &["sum_c"], &[ColumnType::BigInt], &accessor),
vec![sum_expr(
add(
multiply(column(t, "c", &accessor), const_bigint(2)),
const_bigint(1),
),
"sum_c",
)],
"__count__",
tab(t),
equal(column(t, "b", &accessor), const_int128(99)),
Expand All @@ -32,7 +37,7 @@ fn we_can_prove_a_simple_group_by_with_bigint_columns() {
let res = res.verify(&expr, &accessor, &()).unwrap().table;
let expected = owned_table([
bigint("a", [1, 2]),
bigint("sum_c", [101 + 104, 102 + 103]),
bigint("sum_c", [(101 + 104) * 2 + 2, (102 + 103) * 2 + 2]),
bigint("__count__", [2, 2]),
]);
assert_eq!(res, expected);
Expand Down Expand Up @@ -118,7 +123,7 @@ fn we_can_prove_a_complex_group_by_query_with_many_columns() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);

// SELECT scalar_group, int128_group, bigint_group, sum(int128_filter) as sum_int, sum(bigint_filter) as sum_bigint, sum(scalar_filter) as sum_scal, count(*) as __count__
// SELECT scalar_group, int128_group, bigint_group, sum(bigint_sum + 1) as sum_int, sum(bigint_sum - int128_sum) as sum_bigint, sum(scalar_filter) as sum_scal, count(*) as __count__
// FROM sxt.t WHERE int128_filter = 1020 AND varchar_filter = 'f2'
// GROUP BY scalar_group, int128_group, bigint_group
let expr = group_by(
Expand All @@ -127,13 +132,20 @@ fn we_can_prove_a_complex_group_by_query_with_many_columns() {
&["scalar_group", "int128_group", "bigint_group"],
&accessor,
),
sums_expr(
t,
&["bigint_sum", "int128_sum", "scalar_sum"],
&["sum_int", "sum_128", "sum_scal"],
&[ColumnType::BigInt, ColumnType::Int128, ColumnType::Scalar],
&accessor,
),
vec![
sum_expr(
add(column(t, "bigint_sum", &accessor), const_bigint(1)),
"sum_int",
),
sum_expr(
subtract(
column(t, "bigint_sum", &accessor),
column(t, "int128_sum", &accessor),
),
"sum_128",
),
sum_expr(column(t, "scalar_sum", &accessor), "sum_scal"),
],
"__count__",
tab(t),
and(
Expand All @@ -148,24 +160,25 @@ fn we_can_prove_a_complex_group_by_query_with_many_columns() {
scalar("scalar_group", [4, 4, 4]),
int128("int128_group", [8, 8, 9]),
bigint("bigint_group", [6, 7, 6]),
bigint("sum_int", [1406, 927, 637]),
int128("sum_128", [1342, 1262, 513]),
bigint("sum_int", [1409, 929, 638]),
int128("sum_128", [64, -335, 124]),
scalar("sum_scal", [1116, 1033, 375]),
bigint("__count__", [3, 2, 1]),
]);
assert_eq!(res, expected);

// SELECT sum(int128_filter) as sum_int, sum(bigint_filter) as sum_bigint, sum(scalar_filter) as sum_scal, count(*) as __count__
// SELECT sum(bigint_sum) as sum_int, sum(int128_sum * 4) as sum_128, sum(scalar_sum) as sum_scal, count(*) as __count__
// FROM sxt.t WHERE int128_filter = 1020 AND varchar_filter = 'f2'
let expr = group_by(
vec![],
sums_expr(
t,
&["bigint_sum", "int128_sum", "scalar_sum"],
&["sum_int", "sum_128", "sum_scal"],
&[ColumnType::BigInt, ColumnType::Int128, ColumnType::Scalar],
&accessor,
),
vec![
sum_expr(column(t, "bigint_sum", &accessor), "sum_int"),
sum_expr(
multiply(column(t, "int128_sum", &accessor), const_bigint(4)),
"sum_128",
),
sum_expr(column(t, "scalar_sum", &accessor), "sum_scal"),
],
"__count__",
tab(t),
and(
Expand All @@ -178,7 +191,7 @@ fn we_can_prove_a_complex_group_by_query_with_many_columns() {
let res = res.verify(&expr, &accessor, &()).unwrap().table;
let expected = owned_table([
bigint("sum_int", [1406 + 927 + 637]),
int128("sum_128", [1342 + 1262 + 513]),
int128("sum_128", [(1342 + 1262 + 513) * 4]),
scalar("sum_scal", [1116 + 1033 + 375]),
bigint("__count__", [3 + 2 + 1]),
]);
Expand Down
34 changes: 9 additions & 25 deletions crates/proof-of-sql/src/sql/ast/test_utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use super::{
};
use crate::base::{
commitment::Commitment,
database::{ColumnField, ColumnRef, ColumnType, LiteralValue, SchemaAccessor, TableRef},
database::{ColumnRef, LiteralValue, SchemaAccessor, TableRef},
math::decimal::Precision,
};
use proof_of_sql_parser::intermediate_ast::AggregationOperator;

pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> ColumnRef {
let name = name.parse().unwrap();
Expand Down Expand Up @@ -237,35 +238,18 @@ pub fn dense_filter<C: Commitment>(
}

pub fn sum_expr<C: Commitment>(
tab: TableRef,
name: &str,
expr: ProvableExprPlan<C>,
alias: &str,
column_type: ColumnType,
accessor: &impl SchemaAccessor,
) -> (ColumnExpr<C>, ColumnField) {
(
col_expr(tab, name, accessor),
ColumnField::new(alias.parse().unwrap(), column_type),
)
}

pub fn sums_expr<C: Commitment>(
tab: TableRef,
names: &[&str],
aliases: &[&str],
column_types: &[ColumnType],
accessor: &impl SchemaAccessor,
) -> Vec<(ColumnExpr<C>, ColumnField)> {
names
.iter()
.zip(aliases.iter().zip(column_types.iter()))
.map(|(name, (alias, column_type))| sum_expr(tab, name, alias, *column_type, accessor))
.collect()
) -> AliasedProvableExprPlan<C> {
AliasedProvableExprPlan {
expr: ProvableExprPlan::new_aggregate(AggregationOperator::Sum, expr),
alias: alias.parse().unwrap(),
}
}

pub fn group_by<C: Commitment>(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: &str,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
Expand Down
34 changes: 14 additions & 20 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
base::{
commitment::Commitment,
database::{ColumnField, ColumnRef, ColumnType, LiteralValue, TableRef},
database::{ColumnRef, ColumnType, LiteralValue, TableRef},
},
sql::{
ast::{ColumnExpr, GroupByExpr, ProvableExprPlan, TableExpr},
parse::{ConversionError, ConversionResult, WhereExprBuilder},
ast::{AliasedProvableExprPlan, ColumnExpr, GroupByExpr, ProvableExprPlan, TableExpr},
parse::{ConversionError, ConversionResult, ProvableExprPlanBuilder, WhereExprBuilder},
},
};
use proof_of_sql_parser::{
Expand Down Expand Up @@ -251,7 +251,7 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
.collect::<Result<Vec<ColumnExpr<C>>, ConversionError>>()?;
// For a query to be provable the result columns must be of one of three kinds below:
// 1. Group by columns (it is mandatory to have all of them in the correct order)
// 2. Sum(col) expressions (it is optional to have any)
// 2. Sum(expr) expressions (it is optional to have any)
// 3. count(*) with an alias (it is mandatory to have one and only one)
let num_group_by_columns = group_by_exprs.len();
let num_result_columns = value.res_aliased_exprs.len();
Expand Down Expand Up @@ -289,28 +289,22 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
.map(|res| {
if let Expression::Aggregation {
op: AggregationOperator::Sum,
expr,
..
} = (*res.expr).clone()
{
if let Expression::Column(ident) = *expr {
// For sums the outgoing ColumnType is the same as the incoming ColumnType
let column_type = *value
.column_mapping
.get(&ident)
.expect("QueryContext should never allow unknown cols to be in sum")
.column_type();
let res_column_field = ColumnField::new(res.alias, column_type);
let column_expr =
ColumnExpr::new(ColumnRef::new(table.table_ref, ident, column_type));
Some((column_expr, res_column_field))
} else {
None
}
let res_provable_expr_plan =
ProvableExprPlanBuilder::new(&value.column_mapping).build(&res.expr);
res_provable_expr_plan
.ok()
.map(|provable_expr_plan| AliasedProvableExprPlan {
alias: res.alias,
expr: provable_expr_plan,
})
} else {
None
}
})
.collect::<Option<Vec<(ColumnExpr<C>, ColumnField)>>>();
.collect::<Option<Vec<AliasedProvableExprPlan<C>>>>();

// Check count(*)
let count_column = &value.res_aliased_exprs[num_result_columns - 1];
Expand Down
Loading

0 comments on commit e48f938

Please sign in to comment.