diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index 89443af1cf3..78f49ec045f 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -295,11 +295,15 @@ mod tests { use std::ops::Bound; use crate::db::relational_db::tests_utils::make_test_db; + use crate::host::module_host::{DatabaseTableUpdate, TableOp}; + use crate::subscription::query; use spacetimedb_lib::error::ResultTest; use spacetimedb_lib::operator::OpQuery; use spacetimedb_primitives::TableId; + use spacetimedb_sats::data_key::ToDataKey; use spacetimedb_sats::db::def::{ColumnDef, IndexDef, TableDef}; - use spacetimedb_sats::AlgebraicType; + use spacetimedb_sats::relation::MemTable; + use spacetimedb_sats::{product, AlgebraicType}; use spacetimedb_vm::expr::{IndexJoin, IndexScan, JoinExpr, Query}; fn create_table( @@ -1025,4 +1029,96 @@ mod tests { }; Ok(()) } + + #[test] + fn compile_incremental_index_join() -> ResultTest<()> { + let (db, _) = make_test_db()?; + let mut tx = db.begin_tx(); + + // Create table [lhs] with index on [b] + let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)]; + let indexes = &[(1.into(), "b")]; + let lhs_id = create_table(&db, &mut tx, "lhs", schema, indexes)?; + + // Create table [rhs] with index on [b, c] + let schema = &[ + ("b", AlgebraicType::U64), + ("c", AlgebraicType::U64), + ("d", AlgebraicType::U64), + ]; + let indexes = &[(0.into(), "b"), (1.into(), "c")]; + let rhs_id = create_table(&db, &mut tx, "rhs", schema, indexes)?; + + // Should generate an index join since there is an index on `lhs.b`. + // Should push the sargable range condition into the index join's probe side. + let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; + let exp = compile_sql(&db, &tx, sql)?.remove(0); + + let CrudExpr::Query(expr) = exp else { + panic!("unexpected result from compilation: {:#?}", exp); + }; + + // Create an insert for an incremental update. + let row = product!(0u64, 0u64); + let insert = TableOp { + op_type: 1, + row_pk: row.to_data_key().to_bytes(), + row, + }; + let insert = DatabaseTableUpdate { + table_id: lhs_id, + table_name: String::from("lhs"), + ops: vec![insert], + }; + + // Optimize the query plan for the incremental update. + let expr = query::to_mem_table(expr, &insert); + let expr = expr.optimize(); + + let QueryExpr { + source: + SourceExpr::MemTable(MemTable { + head: Header { table_name, .. }, + .. + }), + query, + .. + } = expr + else { + panic!("unexpected result after optimization: {:#?}", expr); + }; + + assert_eq!(table_name, "lhs"); + assert_eq!(query.len(), 1); + + let Query::IndexJoin(IndexJoin { + probe_side: + QueryExpr { + source: SourceExpr::MemTable(_), + query: ref lhs, + }, + probe_field: + FieldName::Name { + table: ref probe_table, + field: ref probe_field, + }, + index_header: _, + index_select: Some(_), + index_table, + index_col, + return_index_rows: false, + }) = query[0] + else { + panic!("unexpected operator {:#?}", query[0]); + }; + + assert!(lhs.is_empty()); + + // Assert that original index and probe tables have been swapped. + assert_eq!(index_table, rhs_id); + assert_eq!(index_col, 0.into()); + assert_eq!(probe_field, "b"); + assert_eq!(probe_table, "lhs"); + Ok(()) + } } diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index d0faef062c4..3bcef2056d8 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -499,13 +499,14 @@ mod tests { let indexes = &[(0.into(), "id")]; let lhs_id = create_table(&db, &mut tx, "lhs", schema, indexes)?; - // Create table [rhs] with no indexes + // Create table [rhs] with index on [id] let schema = &[ ("rid", AlgebraicType::I32), ("id", AlgebraicType::I32), ("y", AlgebraicType::I32), ]; - let rhs_id = create_table(&db, &mut tx, "rhs", schema, &[])?; + let indexes = &[(1.into(), "id")]; + let rhs_id = create_table(&db, &mut tx, "rhs", schema, indexes)?; // Insert into lhs for i in 0..5 { diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 301c074c851..7c763547117 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -527,7 +527,9 @@ impl<'a> IncrementalJoin<'a> { auth: &AuthCtx, ) -> Result, DBError> { let mut inserts = { - let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.inserts()); + // Replan query after replacing left table with virtual table, + // since join order may need to be reversed. + let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.inserts()).optimize(); let rhs_virt = self.to_mem_table_rhs(self.rhs.inserts()); // {A+ join B} @@ -551,7 +553,9 @@ impl<'a> IncrementalJoin<'a> { set }; let mut deletes = { - let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.deletes()); + // Replan query after replacing left table with virtual table, + // since join order may need to be reversed. + let lhs_virt = query::to_mem_table(self.expr.clone(), &self.lhs.deletes()).optimize(); let rhs_virt = self.to_mem_table_rhs(self.rhs.deletes()); // {A- join B} diff --git a/crates/core/src/vm.rs b/crates/core/src/vm.rs index ebf3646089b..0050518f271 100644 --- a/crates/core/src/vm.rs +++ b/crates/core/src/vm.rs @@ -9,9 +9,8 @@ use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::db::def::{ColumnDef, IndexDef, ProductTypeMeta, TableDef}; -use spacetimedb_sats::relation::{ - DbTable, FieldExpr, FieldName, Header, MemTable, RelIter, RelValue, Relation, RowCount, Table, -}; +use spacetimedb_sats::relation::{DbTable, FieldExpr, FieldName, RelValueRef, Relation}; +use spacetimedb_sats::relation::{Header, MemTable, RelIter, RelValue, RowCount, Table}; use spacetimedb_sats::{AlgebraicValue, ProductValue}; use spacetimedb_vm::env::EnvDb; use spacetimedb_vm::errors::ErrorVm; @@ -53,13 +52,31 @@ pub fn build_query<'a>( let iter = result.select(move |row| cmp.compare(row, &header)); Box::new(iter) } + // If this is an index join between two virtual tables, replace with an inner join. + // Such a plan is possible under incremental evaluation, + // when there are updates to both base tables, + // however an index lookup is invalid on a virtual table. + // + // TODO: This logic should be entirely encapsulated within the query planner. + // It should not be possible for the planner to produce an invalid plan. + Query::IndexJoin(join) + if !db_table + && matches!(join.probe_side.source, SourceExpr::MemTable(_)) + && join.probe_side.source.table_name() != result.head().table_name => + { + let join: JoinExpr = join.into(); + let iter = join_inner(ctx, stdb, tx, result, join, true)?; + Box::new(iter) + } Query::IndexJoin(IndexJoin { probe_side, probe_field, index_header, + index_select, index_table, index_col, - }) if db_table => { + return_index_rows, + }) => { let probe_side = build_query(ctx, stdb, tx, probe_side.into())?; Box::new(IndexSemiJoin { ctx, @@ -68,16 +85,13 @@ pub fn build_query<'a>( probe_side, probe_field, index_header, + index_select, index_table, index_col, index_iter: None, + return_index_rows, }) } - Query::IndexJoin(join) => { - let join: JoinExpr = join.into(); - let iter = join_inner(ctx, stdb, tx, result, join, true)?; - Box::new(iter) - } Query::Select(cmp) => { let header = result.head().clone(); let iter = result.select(move |row| cmp.compare(row, &header)); @@ -189,12 +203,15 @@ pub struct IndexSemiJoin<'a, Rhs: RelOps> { // The field whose value will be used to probe the index. pub probe_field: FieldName, // The header for the index side of the join. - // Also the return header since we are returning values from the index side. pub index_header: Header, + // An optional predicate to evaluate over the matching rows of the index. + pub index_select: Option, // The table id on which the index is defined. pub index_table: TableId, // The column id for which the index is defined. pub index_col: ColId, + // Is this a left or right semijion? + pub return_index_rows: bool, // An iterator for the index side. // A new iterator will be instantiated for each row on the probe side. pub index_iter: Option>, @@ -206,9 +223,32 @@ pub struct IndexSemiJoin<'a, Rhs: RelOps> { ctx: &'a ExecutionContext<'a>, } +impl<'a, Rhs: RelOps> IndexSemiJoin<'a, Rhs> { + fn filter(&self, index_row: RelValueRef) -> Result { + if let Some(op) = &self.index_select { + Ok(op.compare(index_row, &self.index_header)?) + } else { + Ok(true) + } + } + + fn map(&self, index_row: RelValue, probe_row: Option) -> RelValue { + if let Some(value) = probe_row { + if !self.return_index_rows { + return value; + } + } + index_row + } +} + impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> { fn head(&self) -> &Header { - &self.index_header + if self.return_index_rows { + &self.index_header + } else { + self.probe_side.head() + } } fn row_count(&self) -> RowCount { @@ -218,8 +258,13 @@ impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> { #[tracing::instrument(skip_all)] fn next(&mut self) -> Result, ErrorVm> { // Return a value from the current index iterator, if not exhausted. - if let Some(value) = self.index_iter.as_mut().and_then(|iter| iter.next()) { - return Ok(Some(value.to_rel_value())); + if self.return_index_rows { + while let Some(value) = self.index_iter.as_mut().and_then(|iter| iter.next()) { + let value = value.to_rel_value(); + if self.filter(value.as_val_ref())? { + return Ok(Some(self.map(value, None))); + } + } } // Otherwise probe the index with a row from the probe side. while let Some(row) = self.probe_side.next()? { @@ -229,9 +274,12 @@ impl<'a, Rhs: RelOps> RelOps for IndexSemiJoin<'a, Rhs> { let col_id = self.index_col; let value = value.clone(); let mut index_iter = self.db.iter_by_col_eq(self.ctx, self.tx, table_id, col_id, value)?; - if let Some(value) = index_iter.next() { - self.index_iter = Some(index_iter); - return Ok(Some(value.to_rel_value())); + while let Some(value) = index_iter.next() { + let value = value.to_rel_value(); + if self.filter(value.as_val_ref())? { + self.index_iter = Some(index_iter); + return Ok(Some(self.map(value, Some(row)))); + } } } } diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index 963a9c7552c..f2a6ab06f68 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -280,6 +280,16 @@ impl From for ColumnOp { } } +impl From for Option { + fn from(value: Query) -> Self { + match value { + Query::IndexScan(op) => Some(op.into()), + Query::Select(op) => Some(op), + _ => None, + } + } +} + #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, From)] pub enum SourceExpr { MemTable(MemTable), @@ -382,14 +392,16 @@ impl From<&SourceExpr> for DbTable { } // A descriptor for an index join operation. -// The semantics are that of a semi-join with rows from the index side being returned. +// The semantics are those of a semijoin with rows from the index or the probe side being returned. #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)] pub struct IndexJoin { pub probe_side: QueryExpr, pub probe_field: FieldName, pub index_header: Header, + pub index_select: Option, pub index_table: TableId, pub index_col: ColId, + pub return_index_rows: bool, } #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)] @@ -1150,8 +1162,10 @@ impl QueryExpr { probe_side, probe_field, index_header: table.head.clone(), + index_select: None, index_table: table.table_id, index_col: col.col_id, + return_index_rows: true, }; return QueryExpr { source, @@ -1217,6 +1231,97 @@ impl QueryExpr { q } + // Is this an incremental evaluation of an index join {L+ join R} + fn is_incremental_index_join(&self) -> bool { + if self.query.len() != 1 { + return false; + } + // Is this in index join? + let Query::IndexJoin(IndexJoin { + probe_side: + QueryExpr { + source: SourceExpr::DbTable(rhs_table), + query: selections, + }, + probe_field, + index_select: None, + return_index_rows: true, + .. + }) = &self.query[0] + else { + return false; + }; + // Is this an incremental evaluation of updates to the left hand table? + let SourceExpr::MemTable(_) = self.source else { + return false; + }; + // Does the right hand table have an index on the join field? + let Some(Column { is_indexed: true, .. }) = rhs_table.head.column(probe_field) else { + return false; + }; + // The original probe side must consist of an optional index scan, + // followed by an arbitrary number of selections. + selections + .iter() + .all(|op| matches!(op, Query::Select(_)) || matches!(op, Query::IndexScan(_))) + } + + // Assuming this is an incremental evaluation of an index join {L+ join R}, + // swap the index and probe sides to avoid scanning all of R. + fn optimize_incremental_index_join(mut self) -> Option { + // This is an index join. + let Some(Query::IndexJoin(IndexJoin { + probe_side: + QueryExpr { + source: SourceExpr::DbTable(rhs_table), + query: selections, + }, + probe_field, + index_header, + index_table: _, + index_col, + index_select: None, + return_index_rows: true, + })) = self.query.pop() + else { + return None; + }; + // This is an incremental evaluation of updates to the left hand table. + let SourceExpr::MemTable(index_side_updates) = self.source else { + return None; + }; + let index_column = index_header.fields.iter().find(|column| column.col_id == index_col)?; + let probe_column = rhs_table.head.column(&probe_field)?; + // Merge all selections from the original probe side into a single predicate. + // This includes an index scan if present. + let predicate = selections.iter().cloned().fold(None, |acc, op| { + >>::into(op).map(|op| { + if let Some(predicate) = acc { + ColumnOp::new(OpQuery::Logic(OpLogic::And), predicate, op) + } else { + op + } + }) + }); + Some(IndexJoin { + // The new probe side consists of the updated rows. + probe_side: index_side_updates.into(), + // The new probe field is the previous index field. + probe_field: index_column.field.clone(), + // The original probe table is now the table that is being probed. + index_header: rhs_table.head.clone(), + // Any selections from the original probe side are pulled above the index lookup. + index_select: predicate, + // The original probe table is now the table that is being probed. + index_table: rhs_table.table_id, + // The new index field is the previous probe field. + index_col: probe_column.col_id, + // Because we have swapped the original index and probe sides of the join, + // the new index join needs to return rows from the probe side instead of the index side. + return_index_rows: false, + }) + } + pub fn optimize(self) -> Self { let mut q = Self { source: self.source.clone(), @@ -1229,6 +1334,14 @@ impl QueryExpr { .flat_map(|x| x.into_iter()) .collect(); + if self.is_incremental_index_join() { + // The above check guarantees that the optimization will succeed, + // and therefore it is safe to unwrap. + let index_join = self.optimize_incremental_index_join().unwrap(); + q.query.push(Query::IndexJoin(index_join)); + return q; + } + for query in self.query { match query { Query::Select(op) => { @@ -1751,8 +1864,10 @@ mod tests { table_name: "bar".into(), fields: vec![], }, + index_select: None, index_table: 42.into(), index_col: 22.into(), + return_index_rows: true, }), Query::JoinInner(JoinExpr { rhs: mem_table.into(),