Skip to content

Commit

Permalink
RLS: Adding a new filter! macro (#1849)
Browse files Browse the repository at this point in the history
Signed-off-by: Mario Montoya <mamcx@elmalabarista.com>
Co-authored-by: joshua-spacetime <josh@clockworklabs.io>
  • Loading branch information
2 people authored and lcodes committed Oct 25, 2024
1 parent 989adac commit 22241f9
Show file tree
Hide file tree
Showing 16 changed files with 269 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/bindings-macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ bench = false

[dependencies]
spacetimedb-primitives.workspace = true
spacetimedb-sql-parser.workspace = true

bitflags.workspace = true
humantime.workspace = true
Expand Down
72 changes: 72 additions & 0 deletions crates/bindings-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use proc_macro::TokenStream as StdTokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::borrow::Cow;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::time::Duration;
use syn::ext::IdentExt;
use syn::meta::ParseNestedMeta;
Expand Down Expand Up @@ -1237,3 +1238,74 @@ pub fn schema_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

fn parse_sql(input: ParseStream) -> syn::Result<String> {
use spacetimedb_sql_parser::parser::sub;

let lookahead = input.lookahead1();
let sql = if lookahead.peek(syn::LitStr) {
let s = input.parse::<syn::LitStr>()?;
// Checks the query is syntactically valid
let _ = sub::parse_subscription(&s.value()).map_err(|e| syn::Error::new(s.span(), format_args!("{e}")))?;

s.value()
} else {
return Err(lookahead.error());
};

Ok(sql)
}

/// Generates code for registering a row-level security `SQL` function.
///
/// A row-level security function takes a `SQL` query expression that is used to filter rows.
///
/// The query follows the same syntax as a subscription query.
///
/// **Example:**
///
/// ```rust,ignore
/// /// Players can only see what's in their chunk
/// spacetimedb::filter!("
/// SELECT * FROM LocationState WHERE chunk_index IN (
/// SELECT chunk_index FROM LocationState WHERE entity_id IN (
/// SELECT entity_id FROM UserState WHERE identity = @sender
/// )
/// )
/// ");
/// ```
///
/// **NOTE:** The `SQL` query expression is pre-parsed at compile time, but only check is a valid
/// subscription query *syntactically*, not that the query is valid when executed.
///
/// For example, it could refer to a non-existent table.
#[proc_macro]
pub fn filter(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let rls_sql = syn::parse_macro_input!(input with parse_sql);

let mut hasher = DefaultHasher::new();
rls_sql.hash(&mut hasher);
let rls_name = format_ident!("rls_{}", hasher.finish());

let register_rls_symbol = format!("__preinit__20_register_{rls_name}");

let generated_describe_function = quote! {
#[export_name = #register_rls_symbol]
extern "C" fn __register_rls() {
spacetimedb::rt::register_row_level_security::<#rls_name>()
}
};

let emission = quote! {
const _: () = {
#generated_describe_function
};
#[allow(non_camel_case_types)]
struct #rls_name { _never: ::core::convert::Infallible }
impl spacetimedb::rt::RowLevelSecurityInfo for #rls_name {
const SQL: &'static str = #rls_sql;
}
};

emission.into()
}
2 changes: 1 addition & 1 deletion crates/bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub use rng::StdbRng;
pub use sats::SpacetimeType;
#[doc(hidden)]
pub use spacetimedb_bindings_macro::__TableHelper;
pub use spacetimedb_bindings_macro::{duration, reducer, table};
pub use spacetimedb_bindings_macro::{duration, filter, reducer, table};
pub use spacetimedb_bindings_sys as sys;
pub use spacetimedb_lib;
pub use spacetimedb_lib::de::{Deserialize, DeserializeOwned};
Expand Down
13 changes: 13 additions & 0 deletions crates/bindings/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,12 @@ impl RepeaterArgs for (Timestamp,) {
}
}

/// A trait for types that can *describe* a row-level security policy.
pub trait RowLevelSecurityInfo {
/// The SQL expression for the row-level security policy.
const SQL: &'static str;
}

/// Registers into `DESCRIBERS` a function `f` to modify the module builder.
fn register_describer(f: fn(&mut ModuleBuilder)) {
DESCRIBERS.lock().unwrap().push(f)
Expand Down Expand Up @@ -352,6 +358,13 @@ pub fn register_reducer<'a, A: Args<'a>, I: ReducerInfo>(_: impl Reducer<'a, A>)
})
}

/// Registers a row-level security policy.
pub fn register_row_level_security<R: RowLevelSecurityInfo>() {
register_describer(|module| {
module.inner.add_row_level_security(R::SQL);
})
}

/// A builder for a module.
#[derive(Default)]
struct ModuleBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
source: crates/bindings/tests/deps.rs
expression: "cargo tree -p spacetimedb -f {lib} -e no-dev"
---
total crates: 62
total crates: 64
spacetimedb
├── bytemuck
├── derive_more
Expand Down Expand Up @@ -48,6 +48,15 @@ spacetimedb
│ │ ├── itertools
│ │ │ └── either
│ │ └── nohash_hasher
│ ├── spacetimedb_sql_parser
│ │ ├── derive_more (*)
│ │ ├── sqlparser
│ │ │ └── log
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ └── syn (*)
├── spacetimedb_bindings_sys
│ └── spacetimedb_primitives (*)
Expand Down Expand Up @@ -75,11 +84,7 @@ spacetimedb
│ │ │ └── equivalent
│ │ ├── nohash_hasher
│ │ ├── smallvec
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ │ └── thiserror (*)
│ ├── spacetimedb_primitives (*)
│ ├── spacetimedb_sats
│ │ ├── arrayvec
Expand Down
20 changes: 19 additions & 1 deletion crates/core/src/db/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use super::datastore::locking_tx_datastore::MutTxId;
use super::relational_db::RelationalDB;
use crate::database_logger::SystemLogger;
use crate::execution_context::ExecutionContext;
use crate::sql::parser::RowLevelExpr;
use spacetimedb_data_structures::map::HashMap;
use spacetimedb_lib::db::auth::StTableType;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::AlgebraicValue;
use spacetimedb_primitives::ColSet;
use spacetimedb_schema::auto_migrate::{AutoMigratePlan, ManualMigratePlan, MigratePlan};
Expand All @@ -24,6 +26,7 @@ use std::sync::Arc;
pub fn update_database(
stdb: &RelationalDB,
tx: &mut MutTxId,
auth_ctx: AuthCtx,
plan: MigratePlan,
system_logger: &SystemLogger,
) -> anyhow::Result<()> {
Expand All @@ -44,7 +47,7 @@ pub fn update_database(

match plan {
MigratePlan::Manual(plan) => manual_migrate_database(stdb, tx, plan, system_logger, existing_tables),
MigratePlan::Auto(plan) => auto_migrate_database(stdb, tx, plan, system_logger, existing_tables),
MigratePlan::Auto(plan) => auto_migrate_database(stdb, tx, auth_ctx, plan, system_logger, existing_tables),
}
}

Expand All @@ -63,6 +66,7 @@ fn manual_migrate_database(
fn auto_migrate_database(
stdb: &RelationalDB,
tx: &mut MutTxId,
auth_ctx: AuthCtx,
plan: AutoMigratePlan,
system_logger: &SystemLogger,
existing_tables: Vec<Arc<TableSchema>>,
Expand Down Expand Up @@ -119,6 +123,7 @@ fn auto_migrate_database(

system_logger.info(&format!("Creating table `{}`", table_name));
log::info!("Creating table `{}`", table_name);

stdb.create_table(tx, table_schema)?;
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::AddIndex(index_name) => {
Expand Down Expand Up @@ -221,6 +226,19 @@ fn auto_migrate_database(
spacetimedb_schema::auto_migrate::AutoMigrateStep::RemoveSchedule(_) => {
anyhow::bail!("Removing schedules is not yet implemented");
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::AddRowLevelSecurity(sql_rls) => {
system_logger.info(&format!("Adding row-level security `{sql_rls}`"));
log::info!("Adding row-level security `{sql_rls}`");
let rls = plan.new.lookup_expect(sql_rls);
let rls = RowLevelExpr::build_row_level_expr(stdb, tx, &auth_ctx, rls)?;

stdb.create_row_level_security(tx, rls.def)?;
}
spacetimedb_schema::auto_migrate::AutoMigrateStep::RemoveRowLevelSecurity(sql_rls) => {
system_logger.info(&format!("Removing-row level security `{sql_rls}`"));
log::info!("Removing row-level security `{sql_rls}`");
stdb.drop_row_level_security(tx, sql_rls.clone())?;
}
}
}

Expand Down
24 changes: 20 additions & 4 deletions crates/core/src/host/wasm_common/module_host_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ use spacetimedb_schema::schema::{Schema, TableSchema};
use std::sync::Arc;
use std::time::Duration;

use spacetimedb_lib::buffer::DecodeError;
use spacetimedb_lib::{bsatn, Address, RawModuleDef};

use super::instrumentation::CallTimes;
use crate::database_logger::SystemLogger;
use crate::db::datastore::locking_tx_datastore::MutTxId;
Expand All @@ -28,10 +25,14 @@ use crate::identity::Identity;
use crate::messages::control_db::HostType;
use crate::module_host_context::ModuleCreationContext;
use crate::replica_context::ReplicaContext;
use crate::sql::parser::RowLevelExpr;
use crate::subscription::module_subscription_actor::WriteConflict;
use crate::util::const_unwrap;
use crate::util::prometheus_handle::HistogramExt;
use crate::worker_metrics::WORKER_METRICS;
use spacetimedb_lib::buffer::DecodeError;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::{bsatn, Address, RawModuleDef};

use super::*;

Expand Down Expand Up @@ -263,6 +264,7 @@ impl<T: WasmInstance> ModuleInstance for WasmModuleInstance<T> {
let timestamp = Timestamp::now();
let stdb = &*self.replica_context().relational_db;
let ctx = ExecutionContext::internal(stdb.address());
let auth_ctx = AuthCtx::for_current(self.replica_context().database.owner_identity);
let tx = stdb.begin_mut_tx(IsolationLevel::Serializable);
let (tx, ()) = stdb
.with_auto_rollback(&ctx, tx, |tx| {
Expand All @@ -276,6 +278,19 @@ impl<T: WasmInstance> ModuleInstance for WasmModuleInstance<T> {
stdb.create_table(tx, schema)
.with_context(|| format!("failed to create table {table_name}"))?;
}
// Insert the late-bound row-level security expressions.
for rls in self.info.module_def.row_level_security() {
self.system_logger()
.info(&format!("Creating row level security `{}`", rls.sql));

let rls = RowLevelExpr::build_row_level_expr(stdb, tx, &auth_ctx, rls)
.with_context(|| format!("failed to create row-level security: `{}`", rls.sql))?;
let table_id = rls.def.table_id;
let sql = rls.def.sql.clone();
stdb.create_row_level_security(tx, rls.def).with_context(|| {
format!("failed to create row-level security for table `{table_id}`: `{sql}`",)
})?;
}

stdb.set_initialized(tx, HostType::Wasm, program)?;

Expand Down Expand Up @@ -335,7 +350,8 @@ impl<T: WasmInstance> ModuleInstance for WasmModuleInstance<T> {
let (mut tx, _) = stdb.with_auto_rollback(&ctx, tx, |tx| stdb.update_program(tx, HostType::Wasm, program))?;
self.system_logger().info(&format!("Updated program to {program_hash}"));

let res = crate::db::update::update_database(stdb, &mut tx, plan, self.system_logger());
let auth_ctx = AuthCtx::for_current(self.replica_context().database.owner_identity);
let res = crate::db::update::update_database(stdb, &mut tx, auth_ctx, plan, self.system_logger());

match res {
Err(e) => {
Expand Down
1 change: 1 addition & 0 deletions crates/core/src/sql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod ast;
pub mod compiler;
pub mod execute;
pub mod parser;
mod type_check;
35 changes: 35 additions & 0 deletions crates/core/src/sql/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::db::datastore::locking_tx_datastore::MutTxId;
use crate::db::relational_db::RelationalDB;
use crate::sql::ast::SchemaViewer;
use spacetimedb_expr::check::parse_and_type_sub;
use spacetimedb_expr::errors::TypingError;
use spacetimedb_expr::expr::RelExpr;
use spacetimedb_expr::ty::TyCtx;
use spacetimedb_lib::db::raw_def::v9::RawRowLevelSecurityDefV9;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_schema::schema::RowLevelSecuritySchema;

pub struct RowLevelExpr {
pub sql: RelExpr,
pub def: RowLevelSecuritySchema,
}

impl RowLevelExpr {
pub fn build_row_level_expr(
stdb: &RelationalDB,
tx: &mut MutTxId,
auth_ctx: &AuthCtx,
rls: &RawRowLevelSecurityDefV9,
) -> Result<Self, TypingError> {
let mut ctx = TyCtx::default();
let sql = parse_and_type_sub(&mut ctx, &rls.sql, &SchemaViewer::new(stdb, tx, auth_ctx))?;

Ok(Self {
def: RowLevelSecuritySchema {
table_id: sql.table_id(&mut ctx)?,
sql: rls.sql.clone(),
},
sql,
})
}
}
12 changes: 9 additions & 3 deletions crates/expr/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use spacetimedb_sql_parser::{ast::BinOp, parser::errors::SqlParseError};
use thiserror::Error;

use super::{
statement::InvalidVar,
ty::{InvalidTypeId, TypeWithCtx},
};
use spacetimedb_sql_parser::ast::BinOp;
use spacetimedb_sql_parser::parser::errors::SqlParseError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum Unresolved {
Expand Down Expand Up @@ -134,6 +134,10 @@ impl UnexpectedType {
#[error("Duplicate name `{0}`")]
pub struct DuplicateName(pub String);

#[derive(Debug, Error)]
#[error("`filter!` does not support column projections; Must return table rows")]
pub struct FilterReturnType;

#[derive(Error, Debug)]
pub enum TypingError {
#[error(transparent)]
Expand Down Expand Up @@ -163,4 +167,6 @@ pub enum TypingError {
Wildcard(#[from] InvalidWildcard),
#[error(transparent)]
DuplicateName(#[from] DuplicateName),
#[error(transparent)]
FilterReturnType(#[from] FilterReturnType),
}
Loading

0 comments on commit 22241f9

Please sign in to comment.