diff --git a/crates/lib/src/error.rs b/crates/lib/src/error.rs index 0485a9eff3..938d52645c 100644 --- a/crates/lib/src/error.rs +++ b/crates/lib/src/error.rs @@ -112,6 +112,8 @@ pub enum AuthError { IndexPrivate { named: String }, #[error("Sequence `{named}` is private")] SequencePrivate { named: String }, + #[error("Only the database owner can perform the requested operation")] + OwnerRequired, } #[derive(thiserror::Error, Debug)] diff --git a/crates/lib/src/identity.rs b/crates/lib/src/identity.rs index 682d387603..f86166807b 100644 --- a/crates/lib/src/identity.rs +++ b/crates/lib/src/identity.rs @@ -36,7 +36,7 @@ impl Identity { const ABBREVIATION_LEN: usize = 16; /// Returns an `Identity` defined as the given `bytes` byte array. - pub fn from_byte_array(bytes: [u8; 32]) -> Self { + pub const fn from_byte_array(bytes: [u8; 32]) -> Self { Self { __identity_bytes: bytes, } diff --git a/crates/lib/src/relation.rs b/crates/lib/src/relation.rs index c4599952cf..1193a63879 100644 --- a/crates/lib/src/relation.rs +++ b/crates/lib/src/relation.rs @@ -602,7 +602,7 @@ impl Relation for DbTable { } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, From)] pub enum Table { MemTable(MemTable), DbTable(DbTable), diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index b3c5bbd7a3..175373cacf 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -11,7 +11,7 @@ use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::satn::Satn; use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; use std::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::fmt; use std::ops::Bound; @@ -510,7 +510,7 @@ impl Ord for IndexScan { } // An individual operation in a query. -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, From)] pub enum Query { // Fetching rows via an index. IndexScan(IndexScan), @@ -531,6 +531,20 @@ pub enum Query { JoinInner(JoinExpr), } +impl Query { + /// Iterate over all [`SourceExpr`]s involved in the [`Query`]. + /// + /// Sources are yielded from left to right. Duplicates are not filtered out. + pub fn sources(&self) -> QuerySources { + match self { + Self::Select(..) | Self::Project(..) => QuerySources::None, + Self::IndexScan(scan) => QuerySources::One(Some(scan.table.clone().into())), + Self::IndexJoin(join) => QuerySources::Expr(join.probe_side.sources()), + Self::JoinInner(join) => QuerySources::Expr(join.rhs.sources()), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct QueryExpr { pub source: SourceExpr, @@ -555,6 +569,26 @@ impl From for QueryExpr { } } +/// Iterator created by the [`Query::sources`] method. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub enum QuerySources { + None, + One(Option), + Expr(QueryExprSources), +} + +impl Iterator for QuerySources { + type Item = SourceExpr; + + fn next(&mut self) -> Option { + match self { + Self::None => None, + Self::One(src) => src.take(), + Self::Expr(expr) => expr.next(), + } + } +} + impl QueryExpr { pub fn new>(source: T) -> Self { Self { @@ -563,6 +597,16 @@ impl QueryExpr { } } + /// Iterate over all [`SourceExpr`]s involved in the [`QueryExpr`]. + /// + /// Sources are yielded from left to right. Duplicates are not filtered out. + pub fn sources(&self) -> QueryExprSources { + QueryExprSources { + head: Some(self.source.clone()), + tail: self.query.iter().map(Query::sources).collect(), + } + } + // Generate an index scan for an equality predicate if this is the first operator. // Otherwise generate a select. // TODO: Replace these methods with a proper query optimization pass. @@ -943,21 +987,48 @@ impl QueryExpr { } } +/// Iterator created by the [`QueryExpr::sources`] method. +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct QueryExprSources { + head: Option, + tail: VecDeque, +} + +impl Iterator for QueryExprSources { + type Item = SourceExpr; + + fn next(&mut self) -> Option { + self.head.take().or_else(|| { + while let Some(cur) = self.tail.front_mut() { + match cur.next() { + None => { + self.tail.pop_front(); + continue; + } + Some(src) => return Some(src), + } + } + + None + }) + } +} + impl AuthAccess for Query { fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { if owner == caller { - Ok(()) - } else if let Query::JoinInner(j) = self { - if j.rhs.source.table_access() == StAccess::Public { - Ok(()) - } else { - Err(AuthError::TablePrivate { - named: j.rhs.source.table_name().to_string(), - }) + return Ok(()); + } + + for table in self.sources() { + if table.table_access() == StAccess::Private { + return Err(AuthError::TablePrivate { + named: table.table_name().to_owned(), + }); } - } else { - Ok(()) } + + Ok(()) } } // @@ -1245,18 +1316,11 @@ impl AuthAccess for QueryCode { return Ok(()); } self.table.check_auth(owner, caller)?; - - if let Some(err) = self.query.iter().find_map(|x| { - if let Err(err) = x.check_auth(owner, caller) { - Some(err) - } else { - None - } - }) { - Err(err) - } else { - Ok(()) + for q in &self.query { + q.check_auth(owner, caller)?; } + + Ok(()) } } @@ -1302,41 +1366,14 @@ impl AuthAccess for CrudCode { if owner == caller { return Ok(()); } - match self { - CrudCode::Query(q) => q.check_auth(owner, caller), - CrudCode::Insert { table, .. } => table.check_auth(owner, caller), - CrudCode::Update { insert, delete } => { - insert.check_auth(owner, caller)?; - delete.check_auth(owner, caller) - } - CrudCode::Delete { query, .. } => query.check_auth(owner, caller), - //TODO: Must allow to create private tables for `caller` - CrudCode::CreateTable { name, table_access, .. } => { - if table_access == &StAccess::Public { - Ok(()) - } else { - Err(AuthError::TablePrivate { - named: name.to_string(), - }) - } - } - CrudCode::Drop { - name, - kind, - table_access, - } => { - if table_access == &StAccess::Public { - Ok(()) - } else { - let named = name.to_string(); - Err(match kind { - DbType::Table => AuthError::TablePrivate { named }, - DbType::Index => AuthError::IndexPrivate { named }, - DbType::Sequence => AuthError::SequencePrivate { named }, - }) - } - } + + // Anyone may query, so as long as the tables involved are public. + if let CrudCode::Query(q) = self { + return q.check_auth(owner, caller); } + + // Mutating operations require `owner == caller`. + Err(AuthError::OwnerRequired) } } @@ -1401,3 +1438,179 @@ impl From for CodeResult { } } } + +#[cfg(test)] +mod tests { + use spacetimedb_sats::ProductType; + + use super::*; + + const ALICE: Identity = Identity::from_byte_array([1; 32]); + const BOB: Identity = Identity::from_byte_array([2; 32]); + + // TODO(kim): Should better do property testing here, but writing generators + // on recursive types (ie. `Query` and friends) is tricky. + + fn tables() -> [Table; 2] { + [ + Table::MemTable(MemTable { + head: Header { + table_name: "foo".into(), + fields: vec![], + }, + data: vec![], + table_access: StAccess::Private, + }), + Table::DbTable(DbTable { + head: Header { + table_name: "foo".into(), + fields: vec![], + }, + table_id: 42, + table_type: StTableType::User, + table_access: StAccess::Private, + }), + ] + } + + fn queries() -> impl IntoIterator { + let [Table::MemTable(mem_table), Table::DbTable(db_table)] = tables() else { + unreachable!() + }; + // Skip `Query::Select` and `QueryProject` -- they don't have table + // information + [ + Query::IndexScan(IndexScan { + table: db_table, + col_id: 42, + lower_bound: Bound::Included(22.into()), + upper_bound: Bound::Unbounded, + }), + Query::IndexJoin(IndexJoin { + probe_side: mem_table.clone().into(), + probe_field: FieldName::Name { + table: "foo".into(), + field: "bar".into(), + }, + index_header: Header { + table_name: "bar".into(), + fields: vec![], + }, + index_table: 42, + index_col: 22, + }), + Query::JoinInner(JoinExpr { + rhs: mem_table.into(), + col_rhs: FieldName::Name { + table: "foo".into(), + field: "id".into(), + }, + col_lhs: FieldName::Name { + table: "bar".into(), + field: "id".into(), + }, + }), + ] + } + + fn query_codes() -> impl IntoIterator { + tables().map(|table| { + let expr = match table { + Table::DbTable(table) => QueryExpr::from(table), + Table::MemTable(table) => QueryExpr::from(table), + }; + let mut code = QueryCode::from(expr); + code.query = queries().into_iter().collect(); + code + }) + } + + fn assert_owner_private(auth: &T) { + assert!(auth.check_auth(ALICE, ALICE).is_ok()); + assert!(matches!( + auth.check_auth(ALICE, BOB), + Err(AuthError::TablePrivate { .. }) + )); + } + + fn assert_owner_required(auth: T) { + assert!(auth.check_auth(ALICE, ALICE).is_ok()); + assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired))); + } + + #[test] + fn test_auth_table() { + tables().iter().for_each(assert_owner_private) + } + + #[test] + fn test_auth_query_code() { + for code in query_codes() { + assert_owner_private(&code) + } + } + + #[test] + fn test_auth_query() { + for query in queries() { + assert_owner_private(&query); + } + } + + #[test] + fn test_auth_crud_code_query() { + for query in query_codes() { + let crud = CrudCode::Query(query); + assert_owner_private(&crud); + } + } + + #[test] + fn test_auth_crud_code_insert() { + for table in tables() { + let crud = CrudCode::Insert { table, rows: vec![] }; + assert_owner_required(crud); + } + } + + #[test] + fn test_auth_crud_code_update() { + let mut qc = query_codes().into_iter(); + let insert = qc.next().unwrap(); + let delete = qc.next().unwrap(); + let crud = CrudCode::Update { insert, delete }; + assert_owner_required(crud); + } + + #[test] + fn test_auth_crud_code_delete() { + for query in query_codes() { + let crud = CrudCode::Delete { query }; + assert_owner_required(crud); + } + } + + #[test] + fn test_auth_crud_code_create_table() { + let crud = CrudCode::CreateTable { + name: "etcpasswd".into(), + columns: ProductTypeMeta { + columns: ProductType { elements: vec![] }, + attr: vec![], + }, + table_type: StTableType::System, // hah! + table_access: StAccess::Public, + }; + assert_owner_required(crud); + } + + #[test] + fn test_auth_crud_code_drop() { + let crud = CrudCode::Drop { + name: "etcpasswd".into(), + kind: DbType::Table, + table_access: StAccess::Public, + }; + assert_owner_required(crud); + } +}