Skip to content

Commit

Permalink
Adding a new filter! macro
Browse files Browse the repository at this point in the history
  • Loading branch information
mamcx committed Oct 14, 2024
1 parent df5b78a commit 6e1b654
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/bindings-macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ bench = false

[dependencies]
spacetimedb-primitives.workspace = true
spacetimedb-sql-parser.workspace = true

bitflags.workspace = true
humantime.workspace = true
Expand Down
72 changes: 72 additions & 0 deletions crates/bindings-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::time::Duration;
use syn::ext::IdentExt;
use syn::meta::ParseNestedMeta;
Expand Down Expand Up @@ -1241,3 +1242,74 @@ pub fn schema_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

fn parse_sql(input: ParseStream) -> syn::Result<String> {
use spacetimedb_sql_parser::parser::sub;

let lookahead = input.lookahead1();
let sql = if lookahead.peek(syn::LitStr) {
let s = input.parse::<syn::LitStr>()?;
// Checks the query is syntactically valid
let _ = sub::parse_subscription(&s.value()).map_err(|e| syn::Error::new(s.span(), format_args!("{e}")))?;

s.value()
} else {
return Err(lookahead.error());
};

Ok(sql)
}

/// Generates code for registering a row-level security `SQL` function.
///
/// A row-level security function takes a `SQL` query expression that is used to filter rows.
///
/// The query follows the same syntax as a subscription query.
///
/// **Example:**
///
/// ```rust,ignore
/// /// Players can only see what's in their chunk
/// spacetimedb::filter!("
/// SELECT * FROM LocationState WHERE chunk_index IN (
/// SELECT chunk_index FROM LocationState WHERE entity_id IN (
/// SELECT entity_id FROM UserState WHERE identity = @sender
/// )
/// )
/// ");
/// ```
///
/// **NOTE:** The `SQL` query expression is pre-parsed at compile time, but only check is a valid
/// subscription query *syntactically*, not that the query is valid when executed.
///
/// For example, it could refer to a non-existent table.
#[proc_macro]
pub fn filter(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let rls_sql = syn::parse_macro_input!(input with parse_sql);

let mut hasher = DefaultHasher::new();
rls_sql.hash(&mut hasher);
let rls_name = format_ident!("rls_{}", hasher.finish());

let register_rls_symbol = format!("__preinit__20_register_{rls_name}");

let generated_describe_function = quote! {
#[export_name = #register_rls_symbol]
extern "C" fn __register_rls() {
spacetimedb::rt::register_row_level_security::<#rls_name>()
}
};

let emission = quote! {
const _: () = {
#generated_describe_function
};
#[allow(non_camel_case_types)]
struct #rls_name { _never: ::core::convert::Infallible }
impl spacetimedb::rt::RowLevelSecurityInfo for #rls_name {
const SQL: &'static str = #rls_sql;
}
};

emission.into()
}
2 changes: 1 addition & 1 deletion crates/bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub use rng::StdbRng;
pub use sats::SpacetimeType;
#[doc(hidden)]
pub use spacetimedb_bindings_macro::__TableHelper;
pub use spacetimedb_bindings_macro::{duration, reducer, table};
pub use spacetimedb_bindings_macro::{duration, filter, reducer, table};
pub use spacetimedb_bindings_sys as sys;
pub use spacetimedb_lib;
pub use spacetimedb_lib::de::{Deserialize, DeserializeOwned};
Expand Down
13 changes: 13 additions & 0 deletions crates/bindings/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ impl RepeaterArgs for (Timestamp,) {
}
}

/// A trait for types that can *describe* a row-level security policy.
pub trait RowLevelSecurityInfo {
/// The SQL expression for the row-level security policy.
const SQL: &'static str;
}

/// Registers into `DESCRIBERS` a function `f` to modify the module builder.
fn register_describer(f: fn(&mut ModuleBuilder)) {
DESCRIBERS.lock().unwrap().push(f)
Expand Down Expand Up @@ -283,6 +289,13 @@ pub fn register_reducer<'a, A: Args<'a>, I: ReducerInfo>(_: impl Reducer<'a, A>)
})
}

/// Registers a row-level security policy.
pub fn register_row_level_security<R: RowLevelSecurityInfo>() {
register_describer(|module| {
module.inner.add_row_level_security(R::SQL);
})
}

/// A builder for a module.
#[derive(Default)]
struct ModuleBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
source: crates/bindings/tests/deps.rs
expression: "cargo tree -p spacetimedb -f {lib} -e no-dev"
---
total crates: 62
total crates: 64
spacetimedb
├── bytemuck
├── derive_more
Expand Down Expand Up @@ -48,6 +48,15 @@ spacetimedb
│ │ ├── itertools
│ │ │ └── either
│ │ └── nohash_hasher
│ ├── spacetimedb_sql_parser
│ │ ├── derive_more (*)
│ │ ├── sqlparser
│ │ │ └── log
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ └── syn (*)
├── spacetimedb_bindings_sys
│ └── spacetimedb_primitives (*)
Expand All @@ -74,11 +83,7 @@ spacetimedb
│ │ │ └── allocator_api2
│ │ ├── nohash_hasher
│ │ ├── smallvec
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ │ └── thiserror (*)
│ ├── spacetimedb_primitives (*)
│ ├── spacetimedb_sats
│ │ ├── arrayvec
Expand Down
28 changes: 25 additions & 3 deletions crates/core/src/db/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use super::datastore::locking_tx_datastore::MutTxId;
use super::relational_db::RelationalDB;
use crate::database_logger::SystemLogger;
use crate::execution_context::ExecutionContext;
use crate::sql::parser::RowLevelExpr;
use spacetimedb_data_structures::map::HashMap;
use spacetimedb_lib::db::auth::StTableType;
use spacetimedb_lib::db::raw_def::v9::RawRowLevelSecurityDefV9;
use spacetimedb_lib::AlgebraicValue;
use spacetimedb_primitives::ColSet;
use spacetimedb_schema::auto_migrate::{AutoMigratePlan, ManualMigratePlan, MigratePlan};
use spacetimedb_schema::def::TableDef;
use spacetimedb_schema::def::{ModuleDefLookup, TableDef};
use spacetimedb_schema::schema::{IndexSchema, Schema, SequenceSchema, TableSchema};
use std::sync::Arc;

Expand Down Expand Up @@ -105,6 +107,9 @@ fn auto_migrate_database(
}
}
}
// Is necessary to collect the full list of old and new tables to pass to the `RowLevelExpr::try_from` method,
// because we removed all the old row-level security definitions, then added the new ones.
let mut all_tables = table_schemas_by_name.clone();

log::info!("Running database update steps: {}", stdb.address());

Expand All @@ -115,11 +120,14 @@ fn auto_migrate_database(

// Recursively sets IDs to 0.
// They will be initialized by the database when the table is created.
let table_schema = TableSchema::from_module_def(plan.new, table_def, (), 0.into());
let mut table_schema = TableSchema::from_module_def(plan.new, table_def, (), 0.into());

system_logger.info(&format!("Creating table `{}`", table_name));
log::info!("Creating table `{}`", table_name);
stdb.create_table(tx, table_schema)?;

let table_id = stdb.create_table(tx, table_schema.clone())?;
table_schema.table_id = table_id;
all_tables.insert(table_schema.table_name.clone(), Arc::new(table_schema));
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::AddIndex(index_name) => {
let table_def = plan.new.stored_in_table_def(index_name).unwrap();
Expand Down Expand Up @@ -221,6 +229,20 @@ fn auto_migrate_database(
spacetimedb_schema::auto_migrate::AutoMigrateStep::RemoveSchedule(_) => {
anyhow::bail!("Removing schedules is not yet implemented");
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::AddRowLevelSecurity(sql_rls) => {
system_logger.info(&format!("Adding row-level security `{sql_rls}`"));
log::info!("Adding row-level security `{sql_rls}`");
let tables = all_tables.values().cloned().collect::<Vec<_>>();
let rls = RawRowLevelSecurityDefV9::lookup(plan.new, sql_rls).unwrap();
let rls = RowLevelExpr::try_from((rls, tables.as_slice()))?;

stdb.create_row_level_security(tx, rls.def)?;
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::RemoveRowLevelSecurity(sql_rls) => {
system_logger.info(&format!("Removing-row level security `{sql_rls}`"));
log::info!("Removing row-level security `{sql_rls}`");
stdb.drop_row_level_security(tx, sql_rls.clone())?;
}
}
}

Expand Down
22 changes: 20 additions & 2 deletions crates/core/src/host/wasm_common/module_host_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::identity::Identity;
use crate::messages::control_db::HostType;
use crate::module_host_context::ModuleCreationContext;
use crate::replica_context::ReplicaContext;
use crate::sql::parser::RowLevelExpr;
use crate::subscription::module_subscription_actor::WriteConflict;
use crate::util::const_unwrap;
use crate::util::prometheus_handle::HistogramExt;
Expand Down Expand Up @@ -268,13 +269,30 @@ impl<T: WasmInstance> ModuleInstance for WasmModuleInstance<T> {
.with_auto_rollback(&ctx, tx, |tx| {
let mut table_defs: Vec<_> = self.info.module_def.tables().collect();
table_defs.sort_by(|a, b| a.name.cmp(&b.name));
let mut table_schemas = Vec::with_capacity(table_defs.len());

for def in table_defs {
let table_name = &def.name;
self.system_logger().info(&format!("Creating table `{table_name}`"));
let schema = TableSchema::from_module_def(&self.info.module_def, def, (), TableId::SENTINEL);
stdb.create_table(tx, schema)
let mut schema = TableSchema::from_module_def(&self.info.module_def, def, (), TableId::SENTINEL);
let table_id = stdb
.create_table(tx, schema.clone())
.with_context(|| format!("failed to create table {table_name}"))?;
schema.table_id = table_id;
table_schemas.push(schema.into());
}
// Insert the late-bound row-level security expressions.
for rls in self.info.module_def.row_level_security() {
self.system_logger()
.info(&format!("Creating row level security `{}`", rls.sql));

let rls = RowLevelExpr::try_from((rls, table_schemas.as_slice()))
.with_context(|| format!("failed to create row-level security: `{}`", rls.sql))?;
let table_id = rls.def.table_id;
let sql = rls.def.sql.clone();
stdb.create_row_level_security(tx, rls.def).with_context(|| {
format!("failed to create row-level security for table `{table_id}`: `{sql}`",)
})?;
}

stdb.set_initialized(tx, HostType::Wasm, program)?;
Expand Down
1 change: 1 addition & 0 deletions crates/core/src/sql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod ast;
pub mod compiler;
pub mod execute;
pub mod parser;
mod type_check;
29 changes: 29 additions & 0 deletions crates/core/src/sql/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use spacetimedb_expr::check::parse_and_type_sub;
use spacetimedb_expr::errors::TypingError;
use spacetimedb_expr::expr::RelExpr;
use spacetimedb_expr::ty::TyCtx;
use spacetimedb_lib::db::raw_def::v9::RawRowLevelSecurityDefV9;
use spacetimedb_schema::schema::{RowLevelSecuritySchema, TableSchema};
use std::sync::Arc;

pub struct RowLevelExpr {
pub sql: RelExpr,
pub def: RowLevelSecuritySchema,
}

impl TryFrom<(&RawRowLevelSecurityDefV9, &[Arc<TableSchema>])> for RowLevelExpr {
type Error = TypingError;

fn try_from((rls, tx): (&RawRowLevelSecurityDefV9, &[Arc<TableSchema>])) -> Result<Self, Self::Error> {
let mut ctx = TyCtx::default();
let sql = parse_and_type_sub(&mut ctx, &rls.sql, &tx)?;

Ok(Self {
def: RowLevelSecuritySchema {
table_id: sql.table_id(&mut ctx)?,
sql: rls.sql.clone(),
},
sql,
})
}
}
5 changes: 5 additions & 0 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ pub trait SchemaView {
fn schema(&self, name: &str) -> Option<Arc<TableSchema>>;
}

impl SchemaView for &[Arc<TableSchema>] {
fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
self.iter().find(|schema| schema.table_name == Box::from(name)).cloned()
}
}
pub trait TypeChecker {
type Ast;
type Set;
Expand Down
12 changes: 9 additions & 3 deletions crates/expr/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use spacetimedb_sql_parser::{ast::BinOp, parser::errors::SqlParseError};
use thiserror::Error;

use super::{
statement::InvalidVar,
ty::{InvalidTypeId, TypeWithCtx},
};
use spacetimedb_sql_parser::ast::BinOp;
use spacetimedb_sql_parser::parser::errors::SqlParseError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum Unresolved {
Expand Down Expand Up @@ -134,6 +134,10 @@ impl UnexpectedType {
#[error("Duplicate name `{0}`")]
pub struct DuplicateName(pub String);

#[derive(Debug, Error)]
#[error("No `TableId` found in `sql` expression")]
pub struct NoTableId;

#[derive(Error, Debug)]
pub enum TypingError {
#[error(transparent)]
Expand Down Expand Up @@ -163,4 +167,6 @@ pub enum TypingError {
Wildcard(#[from] InvalidWildcard),
#[error(transparent)]
DuplicateName(#[from] DuplicateName),
#[error(transparent)]
NoTableId(#[from] NoTableId),
}
14 changes: 11 additions & 3 deletions crates/expr/src/expr.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::sync::Arc;

use crate::errors::{NoTableId, TypingError};
use crate::static_assert_size;
use spacetimedb_lib::AlgebraicValue;
use spacetimedb_primitives::TableId;
use spacetimedb_schema::schema::TableSchema;
use spacetimedb_sql_parser::ast::BinOp;

use crate::static_assert_size;

use super::ty::{InvalidTypeId, Symbol, TyCtx, TyId, TypeWithCtx};
use super::ty::{InvalidTypeId, Symbol, TyCtx, TyId, Type, TypeWithCtx};

/// A logical relational expression
#[derive(Debug)]
Expand Down Expand Up @@ -54,6 +55,13 @@ impl RelExpr {
pub fn ty<'a>(&self, ctx: &'a TyCtx) -> Result<TypeWithCtx<'a>, InvalidTypeId> {
ctx.try_resolve(self.ty_id())
}

pub fn table_id(&self, ctx: &mut TyCtx) -> Result<TableId, TypingError> {
match &*self.ty(ctx)? {
Type::Var(id, _) => Ok(*id),
_ => Err(NoTableId.into()),
}
}
}

/// A relational select operation or filter
Expand Down
Loading

0 comments on commit 6e1b654

Please sign in to comment.