Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

incr-join, find_updates: avoid unncecessary clones & use partition #988

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 56 additions & 101 deletions crates/core/src/subscription/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::ProductValue;
use spacetimedb_primitives::TableId;
use spacetimedb_sats::db::auth::{StAccess, StTableType};
use spacetimedb_sats::relation::{DbTable, Header};
use spacetimedb_sats::relation::DbTable;
use spacetimedb_vm::expr::{self, IndexJoin, Query, QueryExpr, SourceSet};
use spacetimedb_vm::rel_ops::RelOps;
use spacetimedb_vm::relation::MemTable;
Expand Down Expand Up @@ -209,32 +209,21 @@ pub struct IncrementalJoin {

/// One side of an [`IncrementalJoin`].
///
/// Holds the "physical" [`DbTable`] this side of the join operates on, as well
/// as the [`DatabaseTableUpdate`]s pertaining that table.
/// Holds the updates pertaining to a table on one side of the join.
struct JoinSide {
table_id: TableId,
table_name: String,
inserts: Vec<TableOp>,
deletes: Vec<TableOp>,
inserts: Vec<ProductValue>,
deletes: Vec<ProductValue>,
}

impl JoinSide {
/// Return a [`DatabaseTableUpdate`] consisting of only insert operations.
pub fn inserts(&self) -> DatabaseTableUpdate {
DatabaseTableUpdate {
table_id: self.table_id,
table_name: self.table_name.clone(),
ops: self.inserts.to_vec(),
}
/// Return a list of updates consisting of only insert operations.
pub fn inserts(&self) -> Vec<ProductValue> {
self.inserts.clone()
}

/// Return a [`DatabaseTableUpdate`] with only delete operations.
pub fn deletes(&self) -> DatabaseTableUpdate {
DatabaseTableUpdate {
table_id: self.table_id,
table_name: self.table_name.clone(),
ops: self.deletes.to_vec(),
}
/// Return a list of updates with only delete operations.
pub fn deletes(&self) -> Vec<ProductValue> {
self.deletes.clone()
}

/// Does this table update include inserts?
Expand All @@ -249,18 +238,6 @@ impl JoinSide {
}

impl IncrementalJoin {
/// Construct an empty [`DatabaseTableUpdate`] with the schema of `table`
/// to use as a source when pre-compiling `eval_incr` queries.
fn dummy_table_update(table: &DbTable) -> DatabaseTableUpdate {
let table_id = table.table_id;
let table_name = table.head.table_name.clone();
DatabaseTableUpdate {
table_id,
table_name,
ops: vec![],
}
}

fn optimize_query(join: IndexJoin) -> QueryExpr {
let expr = QueryExpr::from(join);
// Because (at least) one of the two tables will be a `MemTable`,
Expand Down Expand Up @@ -313,21 +290,15 @@ impl IncrementalJoin {
.context("expected a physical database table")?
.clone();

let (virtual_index_plan, _sources) =
with_delta_table(join.clone(), Some(Self::dummy_table_update(&index_table)), None);
let (virtual_index_plan, _sources) = with_delta_table(join.clone(), Some(Vec::new()), None);
debug_assert_eq!(_sources.len(), 1);
let virtual_index_plan = Self::optimize_query(virtual_index_plan);

let (virtual_probe_plan, _sources) =
with_delta_table(join.clone(), None, Some(Self::dummy_table_update(&probe_table)));
let (virtual_probe_plan, _sources) = with_delta_table(join.clone(), None, Some(Vec::new()));
debug_assert_eq!(_sources.len(), 1);
let virtual_probe_plan = Self::optimize_query(virtual_probe_plan);

let (virtual_plan, _sources) = with_delta_table(
join.clone(),
Some(Self::dummy_table_update(&index_table)),
Some(Self::dummy_table_update(&probe_table)),
);
let (virtual_plan, _sources) = with_delta_table(join.clone(), Some(Vec::new()), Some(Vec::new()));
debug_assert_eq!(_sources.len(), 2);
let virtual_plan = virtual_plan.to_inner_join();

Expand Down Expand Up @@ -360,46 +331,53 @@ impl IncrementalJoin {
&self,
updates: impl IntoIterator<Item = &'a DatabaseTableUpdate>,
) -> Option<(JoinSide, JoinSide)> {
let mut lhs_ops = Vec::new();
let mut rhs_ops = Vec::new();
let mut lhs_inserts = Vec::new();
let mut lhs_deletes = Vec::new();
let mut rhs_inserts = Vec::new();
let mut rhs_deletes = Vec::new();

// Partitions `updates` into `deletes` and `inserts`.
let partition_into = |deletes: &mut Vec<_>, inserts: &mut Vec<_>, updates: &DatabaseTableUpdate| {
for update in &updates.ops {
if update.op_type == 0 {
&mut *deletes
} else {
&mut *inserts
}
.push(update.row.clone());
}
};

// Partitions all updates into the `(l|r)hs_(insert|delete)_ops` above.
for update in updates {
if update.table_id == self.lhs.table_id {
lhs_ops.extend(update.ops.iter().cloned());
partition_into(&mut lhs_deletes, &mut lhs_inserts, update);
} else if update.table_id == self.rhs.table_id {
rhs_ops.extend(update.ops.iter().cloned());
partition_into(&mut rhs_deletes, &mut rhs_inserts, update);
}
}

if lhs_ops.is_empty() && rhs_ops.is_empty() {
// No updates at all? Return `None`.
if [&lhs_inserts, &lhs_deletes, &rhs_inserts, &rhs_deletes]
.iter()
.all(|ops| ops.is_empty())
{
return None;
}

let lhs = JoinSide {
table_id: self.lhs.table_id,
table_name: self.lhs.head.table_name.clone(),
inserts: lhs_ops.iter().filter(|op| op.op_type == 1).cloned().collect(),
deletes: lhs_ops.iter().filter(|op| op.op_type == 0).cloned().collect(),
};

let rhs = JoinSide {
table_id: self.rhs.table_id,
table_name: self.rhs.head.table_name.clone(),
inserts: rhs_ops.iter().filter(|op| op.op_type == 1).cloned().collect(),
deletes: rhs_ops.iter().filter(|op| op.op_type == 0).cloned().collect(),
};

Some((lhs, rhs))
// Stich together the `JoinSide`s.
let join_side = |deletes, inserts| JoinSide { deletes, inserts };
Some((join_side(lhs_deletes, lhs_inserts), join_side(rhs_deletes, rhs_inserts)))
}

/// Evaluate join plan for lhs updates.
fn eval_lhs(
&self,
db: &RelationalDB,
tx: &Tx,
lhs: DatabaseTableUpdate,
lhs: Vec<ProductValue>,
) -> Result<impl Iterator<Item = ProductValue>, DBError> {
let lhs = to_mem_table(self.lhs.head.clone(), self.lhs.table_access, lhs);
let lhs = MemTable::new(self.lhs.head.clone(), self.lhs.table_access, lhs);
let mut sources = SourceSet::default();
sources.add_mem_table(lhs);
eval_updates(db, tx, self.plan_for_delta_lhs(), sources)
Expand All @@ -410,9 +388,9 @@ impl IncrementalJoin {
&self,
db: &RelationalDB,
tx: &Tx,
rhs: DatabaseTableUpdate,
rhs: Vec<ProductValue>,
) -> Result<impl Iterator<Item = ProductValue>, DBError> {
let rhs = to_mem_table(self.rhs.head.clone(), self.rhs.table_access, rhs);
let rhs = MemTable::new(self.rhs.head.clone(), self.rhs.table_access, rhs);
let mut sources = SourceSet::default();
sources.add_mem_table(rhs);
eval_updates(db, tx, self.plan_for_delta_rhs(), sources)
Expand All @@ -423,11 +401,11 @@ impl IncrementalJoin {
&self,
db: &RelationalDB,
tx: &Tx,
lhs: DatabaseTableUpdate,
rhs: DatabaseTableUpdate,
lhs: Vec<ProductValue>,
rhs: Vec<ProductValue>,
) -> Result<impl Iterator<Item = ProductValue>, DBError> {
let lhs = to_mem_table(self.lhs.head.clone(), self.lhs.table_access, lhs);
let rhs = to_mem_table(self.rhs.head.clone(), self.rhs.table_access, rhs);
let lhs = MemTable::new(self.lhs.head.clone(), self.lhs.table_access, lhs);
let rhs = MemTable::new(self.rhs.head.clone(), self.rhs.table_access, rhs);
let mut sources = SourceSet::default();
let (index_side, probe_side) = if self.return_index_rows { (lhs, rhs) } else { (rhs, lhs) };
sources.add_mem_table(index_side);
Expand Down Expand Up @@ -571,39 +549,25 @@ impl IncrementalJoin {
}
}

/// Construct a [`MemTable`] containing the updates from `delta`,
/// which must be derived from a table with `head` and `table_access`.
fn to_mem_table(head: Arc<Header>, table_access: StAccess, delta: DatabaseTableUpdate) -> MemTable {
MemTable::new(
head,
table_access,
delta.ops.into_iter().map(|op| op.row).collect::<Vec<_>>(),
)
}

/// Replace an [IndexJoin]'s scan or fetch operation with a delta table.
/// A delta table consists purely of updates or changes to the base table.
fn with_delta_table(
mut join: IndexJoin,
index_side: Option<DatabaseTableUpdate>,
probe_side: Option<DatabaseTableUpdate>,
index_side: Option<Vec<ProductValue>>,
probe_side: Option<Vec<ProductValue>>,
) -> (IndexJoin, SourceSet) {
let mut sources = SourceSet::default();

if let Some(index_side) = index_side {
let head = join.index_side.head().clone();
let table_access = join.index_side.table_access();
let mem_table = to_mem_table(head, table_access, index_side);
let source_expr = sources.add_mem_table(mem_table);
join.index_side = source_expr;
join.index_side = sources.add_mem_table(MemTable::new(head, table_access, index_side));
}

if let Some(probe_side) = probe_side {
let head = join.probe_side.source.head().clone();
let table_access = join.probe_side.source.table_access();
let mem_table = to_mem_table(head, table_access, probe_side);
let source_expr = sources.add_mem_table(mem_table);
join.probe_side.source = source_expr;
join.probe_side.source = sources.add_mem_table(MemTable::new(head, table_access, probe_side));
}

(join, sources)
Expand Down Expand Up @@ -720,7 +684,6 @@ pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) ->
mod tests {
use super::*;
use crate::db::relational_db::tests_utils::make_test_db;
use crate::host::module_host::TableOp;
use crate::sql::compiler::compile_sql;
use spacetimedb_lib::error::ResultTest;
use spacetimedb_sats::relation::{DbTable, FieldName};
Expand All @@ -736,7 +699,7 @@ mod tests {
// Create table [lhs] with index on [b]
let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
let indexes = &[(1.into(), "b")];
let lhs_id = db.create_table_for_test("lhs", schema, indexes)?;
let _ = db.create_table_for_test("lhs", schema, indexes)?;

// Create table [rhs] with index on [b, c]
let schema = &[
Expand Down Expand Up @@ -766,11 +729,7 @@ mod tests {
};

// Create an insert for an incremental update.
let delta = DatabaseTableUpdate {
table_id: lhs_id,
table_name: String::from("lhs"),
ops: vec![TableOp::insert(product![0u64, 0u64])],
};
let delta = vec![product![0u64, 0u64]];

// Optimize the query plan for the incremental update.
let (expr, _sources) = with_delta_table(join, Some(delta), None);
Expand Down Expand Up @@ -834,7 +793,7 @@ mod tests {
("d", AlgebraicType::U64),
];
let indexes = &[(0.into(), "b"), (1.into(), "c")];
let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
let _ = db.create_table_for_test("rhs", schema, indexes)?;

let tx = db.begin_tx();
// Should generate an index join since there is an index on `lhs.b`.
Expand All @@ -855,11 +814,7 @@ mod tests {
};

// Create an insert for an incremental update.
let delta = DatabaseTableUpdate {
table_id: rhs_id,
table_name: String::from("rhs"),
ops: vec![TableOp::insert(product![0u64, 0u64, 0u64])],
};
let delta = vec![product![0u64, 0u64, 0u64]];

// Optimize the query plan for the incremental update.
let (expr, _sources) = with_delta_table(join, None, Some(delta));
Expand Down
Loading