Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(553): Optimize incremental join evaluation #557

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion crates/core/src/sql/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,15 @@ mod tests {
use std::ops::Bound;

use crate::db::relational_db::tests_utils::make_test_db;
use crate::host::module_host::{DatabaseTableUpdate, TableOp};
use crate::subscription::query;
use spacetimedb_lib::error::ResultTest;
use spacetimedb_lib::operator::OpQuery;
use spacetimedb_primitives::TableId;
use spacetimedb_sats::data_key::ToDataKey;
use spacetimedb_sats::db::def::{ColumnDef, IndexDef, TableDef};
use spacetimedb_sats::AlgebraicType;
use spacetimedb_sats::relation::MemTable;
use spacetimedb_sats::{product, AlgebraicType};
use spacetimedb_vm::expr::{IndexJoin, IndexScan, JoinExpr, Query};

fn create_table(
Expand Down Expand Up @@ -1025,4 +1029,96 @@ mod tests {
};
Ok(())
}

/// Compiles a two-table join, substitutes an in-memory table holding a single
/// inserted `lhs` row for the left side, re-optimizes, and checks that the
/// resulting plan is an index join that probes the index on `rhs` with rows
/// from the in-memory `lhs` table.
#[test]
fn compile_incremental_index_join() -> ResultTest<()> {
    let (db, _) = make_test_db()?;
    let mut tx = db.begin_tx();

    // Table [lhs] with index on [b]
    let lhs_schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
    let lhs_indexes = &[(1.into(), "b")];
    let lhs_id = create_table(&db, &mut tx, "lhs", lhs_schema, lhs_indexes)?;

    // Table [rhs] with index on [b, c]
    let rhs_schema = &[
        ("b", AlgebraicType::U64),
        ("c", AlgebraicType::U64),
        ("d", AlgebraicType::U64),
    ];
    let rhs_indexes = &[(0.into(), "b"), (1.into(), "c")];
    let rhs_id = create_table(&db, &mut tx, "rhs", rhs_schema, rhs_indexes)?;

    // Should generate an index join since there is an index on `lhs.b`.
    // Should push the sargable range condition into the index join's probe side.
    let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
    let compiled = match compile_sql(&db, &tx, sql)?.remove(0) {
        CrudExpr::Query(expr) => expr,
        other => panic!("unexpected result from compilation: {:#?}", other),
    };

    // Build a single-row insert into [lhs] to drive the incremental update.
    let row = product!(0u64, 0u64);
    let row_pk = row.to_data_key().to_bytes();
    let update = DatabaseTableUpdate {
        table_id: lhs_id,
        table_name: String::from("lhs"),
        ops: vec![TableOp { op_type: 1, row_pk, row }],
    };

    // Swap in the virtual table for [lhs] and replan for the incremental update.
    let optimized = query::to_mem_table(compiled, &update).optimize();

    // The plan's source must now be the in-memory [lhs] table.
    let (table_name, query) = match optimized {
        QueryExpr {
            source:
                SourceExpr::MemTable(MemTable {
                    head: Header { table_name, .. },
                    ..
                }),
            query,
            ..
        } => (table_name, query),
        other => panic!("unexpected result after optimization: {:#?}", other),
    };

    assert_eq!(table_name, "lhs");
    assert_eq!(query.len(), 1);

    // The only operator must be an index join that returns probe-side rows.
    let (lhs, probe_table, probe_field, index_table, index_col) = match &query[0] {
        Query::IndexJoin(IndexJoin {
            probe_side:
                QueryExpr {
                    source: SourceExpr::MemTable(_),
                    query: lhs,
                },
            probe_field: FieldName::Name { table, field },
            index_header: _,
            index_select: Some(_),
            index_table,
            index_col,
            return_index_rows: false,
        }) => (lhs, table, field, *index_table, *index_col),
        other => panic!("unexpected operator {:#?}", other),
    };

    assert!(lhs.is_empty());

    // Assert that original index and probe tables have been swapped.
    assert_eq!(index_table, rhs_id);
    assert_eq!(index_col, 0.into());
    assert_eq!(probe_field, "b");
    assert_eq!(probe_table, "lhs");
    Ok(())
}
}
5 changes: 3 additions & 2 deletions crates/core/src/subscription/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,14 @@ mod tests {
let indexes = &[(0.into(), "id")];
let lhs_id = create_table(&db, &mut tx, "lhs", schema, indexes)?;

// Create table [rhs] with no indexes
// Create table [rhs] with index on [id]
let schema = &[
("rid", AlgebraicType::I32),
("id", AlgebraicType::I32),
("y", AlgebraicType::I32),
];
let rhs_id = create_table(&db, &mut tx, "rhs", schema, &[])?;
let indexes = &[(1.into(), "id")];
let rhs_id = create_table(&db, &mut tx, "rhs", schema, indexes)?;

// Insert into lhs
for i in 0..5 {
Expand Down
8 changes: 6 additions & 2 deletions crates/core/src/subscription/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,9 @@ impl<'a> IncrementalJoin<'a> {
auth: &AuthCtx,
) -> Result<impl Iterator<Item = Op>, DBError> {
let mut inserts = {
let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.inserts());
// Replan query after replacing left table with virtual table,
// since join order may need to be reversed.
let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.inserts()).optimize();
let rhs_virt = self.to_mem_table_rhs(self.rhs.inserts());

// {A+ join B}
Expand All @@ -551,7 +553,9 @@ impl<'a> IncrementalJoin<'a> {
set
};
let mut deletes = {
let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.deletes());
// Replan query after replacing left table with virtual table,
// since join order may need to be reversed.
let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.deletes()).optimize();
let rhs_virt = self.to_mem_table_rhs(self.rhs.deletes());

// {A- join B}
Expand Down
80 changes: 64 additions & 16 deletions crates/core/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_primitives::{ColId, TableId};
use spacetimedb_sats::db::auth::{StAccess, StTableType};
use spacetimedb_sats::db::def::{ColumnDef, IndexDef, ProductTypeMeta, TableDef};
use spacetimedb_sats::relation::{
DbTable, FieldExpr, FieldName, Header, MemTable, RelIter, RelValue, Relation, RowCount, Table,
};
use spacetimedb_sats::relation::{DbTable, FieldExpr, FieldName, RelValueRef, Relation};
use spacetimedb_sats::relation::{Header, MemTable, RelIter, RelValue, RowCount, Table};
use spacetimedb_sats::{AlgebraicValue, ProductValue};
use spacetimedb_vm::env::EnvDb;
use spacetimedb_vm::errors::ErrorVm;
Expand Down Expand Up @@ -53,13 +52,31 @@ pub fn build_query<'a>(
let iter = result.select(move |row| cmp.compare(row, &header));
Box::new(iter)
}
// If this is an index join between two virtual tables, replace with an inner join.
// Such a plan is possible under incremental evaluation,
// when there are updates to both base tables.
// However, an index lookup is invalid on a virtual table.
//
// TODO: This logic should be entirely encapsulated within the query planner.
// It should not be possible for the planner to produce an invalid plan.
Query::IndexJoin(join)
if !db_table
&& matches!(join.probe_side.source, SourceExpr::MemTable(_))
&& join.probe_side.source.table_name() != result.head().table_name =>
{
let join: JoinExpr = join.into();
let iter = join_inner(ctx, stdb, tx, result, join, true)?;
Box::new(iter)
}
Query::IndexJoin(IndexJoin {
probe_side,
probe_field,
index_header,
index_select,
index_table,
index_col,
}) if db_table => {
return_index_rows,
}) => {
let probe_side = build_query(ctx, stdb, tx, probe_side.into())?;
Box::new(IndexSemiJoin {
ctx,
Expand All @@ -68,16 +85,13 @@ pub fn build_query<'a>(
probe_side,
probe_field,
index_header,
index_select,
index_table,
index_col,
index_iter: None,
return_index_rows,
})
}
Query::IndexJoin(join) => {
let join: JoinExpr = join.into();
let iter = join_inner(ctx, stdb, tx, result, join, true)?;
Box::new(iter)
}
Query::Select(cmp) => {
let header = result.head().clone();
let iter = result.select(move |row| cmp.compare(row, &header));
Expand Down Expand Up @@ -189,12 +203,15 @@ pub struct IndexSemiJoin<'a, Rhs: RelOps> {
// The field whose value will be used to probe the index.
pub probe_field: FieldName,
// The header for the index side of the join.
// Also the return header since we are returning values from the index side.
pub index_header: Header,
// An optional predicate to evaluate over the matching rows of the index.
pub index_select: Option<ColumnOp>,
// The table id on which the index is defined.
pub index_table: TableId,
// The column id for which the index is defined.
pub index_col: ColId,
// Is this a left or right semijoin?
// If true, rows from the index side are returned; otherwise rows from the probe side.
pub return_index_rows: bool,
// An iterator for the index side.
// A new iterator will be instantiated for each row on the probe side.
pub index_iter: Option<IterByColEq<'a>>,
Expand All @@ -206,9 +223,32 @@ pub struct IndexSemiJoin<'a, Rhs: RelOps> {
ctx: &'a ExecutionContext<'a>,
}

impl<'a, Rhs: RelOps> IndexSemiJoin<'a, Rhs> {
    /// Evaluates the optional `index_select` predicate against `index_row`,
    /// using `index_header` to resolve fields.
    ///
    /// Returns `Ok(true)` when no predicate is set; any error from the
    /// comparison is propagated.
    fn filter(&self, index_row: RelValueRef) -> Result<bool, ErrorVm> {
        match &self.index_select {
            Some(op) => Ok(op.compare(index_row, &self.index_header)?),
            None => Ok(true),
        }
    }

    /// Chooses which side's row to emit for a match.
    ///
    /// The probe row is returned only when one is supplied and
    /// `return_index_rows` is false; in every other case the index row is
    /// returned.
    fn map(&self, index_row: RelValue, probe_row: Option<RelValue>) -> RelValue {
        match probe_row {
            Some(value) if !self.return_index_rows => value,
            _ => index_row,
        }
    }
}

impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> {
fn head(&self) -> &Header {
&self.index_header
if self.return_index_rows {
&self.index_header
} else {
self.probe_side.head()
}
}

fn row_count(&self) -> RowCount {
Expand All @@ -218,8 +258,13 @@ impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> {
#[tracing::instrument(skip_all)]
fn next(&mut self) -> Result<Option<RelValue>, ErrorVm> {
// Return a value from the current index iterator, if not exhausted.
if let Some(value) = self.index_iter.as_mut().and_then(|iter| iter.next()) {
return Ok(Some(value.to_rel_value()));
if self.return_index_rows {
while let Some(value) = self.index_iter.as_mut().and_then(|iter| iter.next()) {
let value = value.to_rel_value();
if self.filter(value.as_val_ref())? {
return Ok(Some(self.map(value, None)));
}
}
}
// Otherwise probe the index with a row from the probe side.
while let Some(row) = self.probe_side.next()? {
Expand All @@ -229,9 +274,12 @@ impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> {
let col_id = self.index_col;
let value = value.clone();
let mut index_iter = self.db.iter_by_col_eq(self.ctx, self.tx, table_id, col_id, value)?;
if let Some(value) = index_iter.next() {
self.index_iter = Some(index_iter);
return Ok(Some(value.to_rel_value()));
while let Some(value) = index_iter.next() {
let value = value.to_rel_value();
if self.filter(value.as_val_ref())? {
self.index_iter = Some(index_iter);
return Ok(Some(self.map(value, Some(row))));
}
}
}
}
Expand Down
Loading
Loading