diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index f23406a10f..3178547dd1 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -4,7 +4,7 @@ use spacetimedb::execution_context::Workload; use spacetimedb::host::module_host::DatabaseTableUpdate; use spacetimedb::identity::AuthCtx; use spacetimedb::messages::websocket::BsatnFormat; -use spacetimedb::subscription::query::compile_read_only_query; +use spacetimedb::subscription::query::compile_read_only_queryset; use spacetimedb::subscription::subscription::ExecutionSet; use spacetimedb::{db::relational_db::RelationalDB, messages::websocket::Compression}; use spacetimedb_bench::database::BenchDatabase as _; @@ -102,7 +102,7 @@ fn eval(c: &mut Criterion) { let bench_eval = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { let tx = raw.db.begin_tx(Workload::Update); - let query = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap(); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap(); let query: ExecutionSet = query.into(); b.iter(|| { @@ -141,8 +141,8 @@ fn eval(c: &mut Criterion) { let select_lhs = "select * from footprint"; let select_rhs = "select * from location"; let tx = &raw.db.begin_tx(Workload::Update); - let query_lhs = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap(); - let query_rhs = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap(); + let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap(); + let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap(); let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs)); let tx = &tx.into(); @@ -160,7 +160,7 @@ fn eval(c: &mut Criterion) { where location.chunk_index = {chunk_index}" ); let tx = &raw.db.begin_tx(Workload::Update); - let query = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap(); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap(); let query: ExecutionSet = query.into(); let tx = &tx.into(); diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index dcd8731354..17e36e5967 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -90,6 +90,10 @@ pub enum ClientMessage { Subscribe(Subscribe), /// Send a one-off SQL query without establishing a subscription. OneOffQuery(OneOffQuery), + /// Register a SQL query to subscribe to updates. This does not affect other subscriptions. + SubscribeSingle(SubscribeSingle), + /// Remove a subscription to a SQL query that was added with SubscribeSingle. + Unsubscribe(Unsubscribe), } impl ClientMessage { @@ -106,8 +110,10 @@ impl ClientMessage { request_id, flags, }), - ClientMessage::Subscribe(x) => ClientMessage::Subscribe(x), ClientMessage::OneOffQuery(x) => ClientMessage::OneOffQuery(x), + ClientMessage::SubscribeSingle(x) => ClientMessage::SubscribeSingle(x), + ClientMessage::Unsubscribe(x) => ClientMessage::Unsubscribe(x), + ClientMessage::Subscribe(x) => ClientMessage::Subscribe(x), } } } @@ -162,6 +168,20 @@ impl_deserialize!([] CallReducerFlags, de => match de.deserialize_u8()? { x => Err(D::Error::custom(format_args!("invalid call reducer flag {x}"))), }); +/// An opaque id generated by the client to refer to a subscription. 
+/// This is used in Unsubscribe messages and errors. +#[derive(SpacetimeType, Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[sats(crate = spacetimedb_lib)] +pub struct QueryId { + pub id: u32, +} + +impl QueryId { + pub fn new(id: u32) -> Self { + Self { id } + } +} + /// Sent by client to database to register a set of queries, about which the client will /// receive `TransactionUpdate`s. /// @@ -184,6 +204,41 @@ pub struct Subscribe { pub request_id: u32, } +/// Sent by client to register a subscription to a single query, for which the client should receive +/// relevant `TransactionUpdate`s. +/// +/// After issuing a `SubscribeSingle` message, the client will receive a single +/// `SubscribeApplied` message containing every current row which matches the query. Then, any +/// time a reducer updates the query's results, the client will receive a `TransactionUpdate` +/// containing the relevant updates. +/// +/// If a client subscribes to queries with overlapping results, the client will receive +/// multiple copies of rows that appear in multiple queries. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct SubscribeSingle { + /// A single SQL `SELECT` query to subscribe to. + pub query: Box, + /// An identifier for a client request. + pub request_id: u32, + + /// An identifier for this subscription, which should not be used for any other subscriptions on the same connection. + /// This is used to refer to this subscription in Unsubscribe messages from the client and errors sent from the server. + /// These only have meaning given a ConnectionId. + pub query_id: QueryId, +} + +/// Client request for removing a query from a subscription. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct Unsubscribe { + /// An identifier for a client request. + pub request_id: u32, + + /// The ID used in the corresponding `SubscribeSingle` message. + pub query_id: QueryId, +} + /// A one-off query submission. /// /// Query should be a "SELECT * FROM Table WHERE ...". Other types of queries will be rejected. @@ -213,6 +268,7 @@ pub const SERVER_MSG_COMPRESSION_TAG_GZIP: u8 = 2; #[sats(crate = spacetimedb_lib)] pub enum ServerMessage { /// Informs of changes to subscribed rows. + /// This will be removed when we switch to `SubscribeSingle`. InitialSubscription(InitialSubscription), /// Upon reducer run. TransactionUpdate(TransactionUpdate), @@ -222,6 +278,97 @@ pub enum ServerMessage { IdentityToken(IdentityToken), /// Return results to a one off SQL query. OneOffQueryResponse(OneOffQueryResponse), + /// Sent in response to a `SubscribeSingle` message. This contains the initial matching rows. + SubscribeApplied(SubscribeApplied), + /// Sent in response to an `Unsubscribe` message. This contains the matching rows. + UnsubscribeApplied(UnsubscribeApplied), + /// Communicate an error in the subscription lifecycle. + SubscriptionError(SubscriptionError), +} + +/// The matching rows of a subscription query. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct SubscribeRows { + /// The table ID of the query. + pub table_id: TableId, + /// The table name of the query. + pub table_name: Box, + /// The BSATN row values. + pub table_rows: TableUpdate, +} + +/// Response to [`SubscribeSingle`] containing the initial matching rows. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct SubscribeApplied { + /// The request_id of the corresponding `SubscribeSingle` message. 
+ pub request_id: u32, + /// The overall time between the server receiving a request and sending the response. + pub total_host_execution_duration_micros: u64, + /// An identifier for the subscribed query sent by the client. + pub query_id: QueryId, + /// The matching rows for this query. + pub rows: SubscribeRows, +} + +/// Server response to a client [`Unsubscribe`] request. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct UnsubscribeApplied { + /// Provided by the client via the `Subscribe` message. + /// TODO: switch to subscription id? + pub request_id: u32, + /// The overall time between the server receiving a request and sending the response. + pub total_host_execution_duration_micros: u64, + /// The ID included in the `SubscribeApplied` and `Unsubscribe` messages. + pub query_id: QueryId, + /// The matching rows for this query. + /// Note, this makes unsubscribing potentially very expensive. + /// To remove this in the future, we would need to send query_ids with rows in transaction updates, + /// and we would need clients to track which rows exist in which queries. + pub rows: SubscribeRows, +} + +/// Server response to an error at any point of the subscription lifecycle. +/// If this error doesn't have a request_id, the client should drop all subscriptions. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct SubscriptionError { + /// The overall time between the server receiving a request and sending the response. + pub total_host_execution_duration_micros: u64, + /// Provided by the client via a [`Subscribe`] or [`Unsubscribe`] message. + /// [`None`] if this occurred as the result of a [`TransactionUpdate`]. + pub request_id: Option, + /// The return table of the query in question. + /// The server is not required to set this field. + /// It has been added to avoid a breaking change post 1.0. + /// + /// If unset, an error results in the entire subscription being dropped. + /// Otherwise only queries of this table type must be dropped. + pub table_id: Option, + /// An error message describing the failure. + /// + /// This should reference specific fragments of the query where applicable, + /// but should not include the full text of the query, + /// as the client can retrieve that from the `request_id`. + /// + /// This is intended for diagnostic purposes. + /// It need not have a predictable/parseable format. + pub error: Box, +} + +/// Response to [`Subscribe`] containing the initial matching rows. +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct SubscriptionUpdate { + /// A [`DatabaseUpdate`] containing only inserts, the rows which match the subscription queries. + pub database_update: DatabaseUpdate, + /// An identifier sent by the client in requests. + /// The server will include the same request_id in the response. + pub request_id: u32, + /// The overall time between the server receiving a request and sending the response. + pub total_host_execution_duration_micros: u64, } /// Response to [`Subscribe`] containing the initial matching rows. 
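For orientation, here is a minimal client-side sketch of the new single-query lifecycle using the types introduced above. It is not part of the patch: the generic parameters are elided in this diff, so the sketch assumes `SubscribeSingle::query` is a boxed SQL string, and the SQL text and id values are placeholders.

```rust
use spacetimedb_client_api_messages::websocket::{QueryId, SubscribeSingle, Unsubscribe};

fn single_query_lifecycle() -> (SubscribeSingle, Unsubscribe) {
    // The client picks a QueryId that is unique for this connection and
    // reuses it later to drop exactly this subscription.
    let query_id = QueryId::new(1);

    // Register one query; other subscriptions on the connection are unaffected.
    let subscribe = SubscribeSingle {
        query: "SELECT * FROM location".into(), // assumes `query` is a boxed SQL string
        request_id: 1,
        query_id,
    };

    // The server answers with SubscribeApplied carrying the initial matching rows,
    // then streams relevant TransactionUpdates until the client unsubscribes.
    let unsubscribe = Unsubscribe {
        request_id: 2,
        query_id,
    };

    (subscribe, unsubscribe)
}
```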
@@ -397,6 +544,15 @@ impl TableUpdate { } } + pub fn empty(table_id: TableId, table_name: Box) -> Self { + Self { + table_id, + table_name, + num_rows: 0, + updates: SmallVec::new(), + } + } + pub fn push(&mut self, (update, num_rows): (F::QueryUpdate, u64)) { self.updates.push(update); self.num_rows += num_rows; diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index fb646581a7..92621c631a 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -12,7 +12,9 @@ use crate::util::prometheus_handle::IntGaugeExt; use crate::worker_metrics::WORKER_METRICS; use derive_more::From; use futures::prelude::*; -use spacetimedb_client_api_messages::websocket::{CallReducerFlags, Compression, FormatSwitch}; +use spacetimedb_client_api_messages::websocket::{ + CallReducerFlags, Compression, FormatSwitch, SubscribeSingle, Unsubscribe, +}; use spacetimedb_lib::identity::RequestId; use tokio::sync::{mpsc, oneshot, watch}; use tokio::task::AbortHandle; @@ -283,12 +285,30 @@ impl ClientConnection { .await } + pub async fn subscribe_single(&self, subscription: SubscribeSingle, timer: Instant) -> Result<(), DBError> { + let me = self.clone(); + tokio::task::spawn_blocking(move || { + me.module + .subscriptions() + .add_subscription(me.sender, subscription, timer, None) + }) + .await + .unwrap() // TODO: is unwrapping right here? + } + + pub async fn unsubscribe(&self, request: Unsubscribe, timer: Instant) -> Result<(), DBError> { + let me = self.clone(); + tokio::task::spawn_blocking(move || me.module.subscriptions().remove_subscription(me.sender, request, timer)) + .await + .unwrap() // TODO: is unwrapping right here? + } + pub async fn subscribe(&self, subscription: Subscribe, timer: Instant) -> Result<(), DBError> { let me = self.clone(); tokio::task::spawn_blocking(move || { me.module .subscriptions() - .add_subscriber(me.sender, subscription, timer, None) + .add_legacy_subscriber(me.sender, subscription, timer, None) }) .await .unwrap() diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index af59b10160..c2e4312d42 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -79,6 +79,22 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst ) }) } + ClientMessage::SubscribeSingle(subscription) => { + let res = client.subscribe_single(subscription, timer).await; + WORKER_METRICS + .request_round_trip + .with_label_values(&WorkloadType::Subscribe, &address, "") + .observe(timer.elapsed().as_secs_f64()); + res.map_err(|e| (None, None, e.into())) + } + ClientMessage::Unsubscribe(request) => { + let res = client.unsubscribe(request, timer).await; + WORKER_METRICS + .request_round_trip + .with_label_values(&WorkloadType::Unsubscribe, &address, "") + .observe(timer.elapsed().as_secs_f64()); + res.map_err(|e| (None, None, e.into())) + } ClientMessage::Subscribe(subscription) => { let res = client.subscribe(subscription, timer).await; WORKER_METRICS diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index eacc330ebf..305758d605 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -11,6 +11,7 @@ use spacetimedb_client_api_messages::websocket::{ use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; use spacetimedb_lib::Address; +use spacetimedb_primitives::TableId; use 
spacetimedb_sats::bsatn; use spacetimedb_vm::relation::MemTable; use std::sync::Arc; @@ -66,6 +67,7 @@ pub enum SerializableMessage { Query(OneOffQueryResponseMessage), Identity(IdentityTokenMessage), Subscribe(SubscriptionUpdateMessage), + Subscription(SubscriptionMessage), TxUpdate(TransactionUpdateMessage), } @@ -74,6 +76,7 @@ impl SerializableMessage { match self { Self::Query(msg) => Some(msg.num_rows()), Self::Subscribe(msg) => Some(msg.num_rows()), + Self::Subscription(msg) => Some(msg.num_rows()), Self::TxUpdate(msg) => Some(msg.num_rows()), Self::Identity(_) => None, } @@ -83,6 +86,11 @@ impl SerializableMessage { match self { Self::Query(_) => Some(WorkloadType::Sql), Self::Subscribe(_) => Some(WorkloadType::Subscribe), + Self::Subscription(msg) => match &msg.result { + SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe), + SubscriptionResult::Unsubscribe(_) => Some(WorkloadType::Unsubscribe), + SubscriptionResult::Error(_) => None, + }, Self::TxUpdate(_) => Some(WorkloadType::Update), Self::Identity(_) => None, } @@ -97,6 +105,7 @@ impl ToProtocol for SerializableMessage { SerializableMessage::Identity(msg) => msg.to_protocol(protocol), SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol), SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol), + SerializableMessage::Subscription(msg) => msg.to_protocol(protocol), } } } @@ -242,6 +251,156 @@ impl ToProtocol for SubscriptionUpdateMessage { } } +#[derive(Debug, Clone)] +pub struct SubscriptionRows { + pub table_id: TableId, + pub table_name: Box, + pub table_rows: FormatSwitch, ws::TableUpdate>, +} + +impl ToProtocol for SubscriptionRows { + type Encoded = FormatSwitch, ws::SubscribeRows>; + fn to_protocol(self, protocol: Protocol) -> Self::Encoded { + protocol.assert_matches_format_switch(&self.table_rows); + match self.table_rows { + FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(ws::SubscribeRows { + table_id: self.table_id, + table_name: self.table_name, + table_rows, + }), + FormatSwitch::Json(table_rows) => FormatSwitch::Json(ws::SubscribeRows { + table_id: self.table_id, + table_name: self.table_name, + table_rows, + }), + } + } +} + +#[derive(Debug, Clone)] +pub struct SubscriptionError { + pub table_id: Option, + pub message: Box, +} + +#[derive(Debug, Clone)] +pub enum SubscriptionResult { + Subscribe(SubscriptionRows), + Unsubscribe(SubscriptionRows), + Error(SubscriptionError), +} + +#[derive(Debug, Clone)] +pub struct SubscriptionMessage { + pub timer: Option, + pub request_id: Option, + pub query_id: Option, + pub result: SubscriptionResult, +} + +fn num_rows_in(rows: &SubscriptionRows) -> usize { + match &rows.table_rows { + FormatSwitch::Bsatn(x) => x.num_rows(), + FormatSwitch::Json(x) => x.num_rows(), + } +} + +impl SubscriptionMessage { + fn num_rows(&self) -> usize { + match &self.result { + SubscriptionResult::Subscribe(x) => num_rows_in(x), + SubscriptionResult::Unsubscribe(x) => num_rows_in(x), + _ => 0, + } + } +} + +impl ToProtocol for SubscriptionMessage { + type Encoded = SwitchedServerMessage; + fn to_protocol(self, protocol: Protocol) -> Self::Encoded { + let request_id = self.request_id.unwrap_or(0); + let query_id = self.query_id.unwrap_or(ws::QueryId::new(0)); + let total_host_execution_duration_micros = self.timer.map_or(0, |t| t.elapsed().as_micros() as u64); + + match self.result { + SubscriptionResult::Subscribe(result) => { + protocol.assert_matches_format_switch(&result.table_rows); + match result.table_rows { + FormatSwitch::Bsatn(table_rows) => 
FormatSwitch::Bsatn( + ws::SubscribeApplied { + total_host_execution_duration_micros, + request_id, + query_id, + rows: ws::SubscribeRows { + table_id: result.table_id, + table_name: result.table_name, + table_rows, + }, + } + .into(), + ), + FormatSwitch::Json(table_rows) => FormatSwitch::Json( + ws::SubscribeApplied { + total_host_execution_duration_micros, + request_id, + query_id, + rows: ws::SubscribeRows { + table_id: result.table_id, + table_name: result.table_name, + table_rows, + }, + } + .into(), + ), + } + } + SubscriptionResult::Unsubscribe(result) => { + protocol.assert_matches_format_switch(&result.table_rows); + match result.table_rows { + FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn( + ws::UnsubscribeApplied { + total_host_execution_duration_micros, + request_id, + query_id, + rows: ws::SubscribeRows { + table_id: result.table_id, + table_name: result.table_name, + table_rows, + }, + } + .into(), + ), + FormatSwitch::Json(table_rows) => FormatSwitch::Json( + ws::UnsubscribeApplied { + total_host_execution_duration_micros, + request_id, + query_id, + rows: ws::SubscribeRows { + table_id: result.table_id, + table_name: result.table_name, + table_rows, + }, + } + .into(), + ), + } + } + SubscriptionResult::Error(error) => { + let msg = ws::SubscriptionError { + total_host_execution_duration_micros, + request_id: self.request_id, // Pass Option through + table_id: error.table_id, + error: error.message, + }; + match protocol { + Protocol::Binary => FormatSwitch::Bsatn(msg.into()), + Protocol::Text => FormatSwitch::Json(msg.into()), + } + } + } + } +} + #[derive(Debug)] pub struct OneOffQueryResponseMessage { pub message_id: Vec, diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index bc5661ef89..6eba969e53 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -103,6 +103,8 @@ pub enum SubscriptionError { SideEffect(Crud), #[error("Unsupported query on subscription: {0:?}")] Unsupported(String), + #[error("Subscribing to multiple queries in one call is not supported")] + Multiple, } #[derive(Error, Debug)] diff --git a/crates/core/src/execution_context.rs b/crates/core/src/execution_context.rs index fa084404b2..31ef1ecc3b 100644 --- a/crates/core/src/execution_context.rs +++ b/crates/core/src/execution_context.rs @@ -100,6 +100,7 @@ pub enum Workload { Reducer(ReducerContext), Sql, Subscribe, + Unsubscribe, Update, Internal, } @@ -113,6 +114,7 @@ pub enum WorkloadType { Reducer, Sql, Subscribe, + Unsubscribe, Update, Internal, } @@ -125,6 +127,7 @@ impl From for WorkloadType { Workload::Reducer(_) => Self::Reducer, Workload::Sql => Self::Sql, Workload::Subscribe => Self::Subscribe, + Workload::Unsubscribe => Self::Unsubscribe, Workload::Update => Self::Update, Workload::Internal => Self::Internal, } @@ -156,6 +159,7 @@ impl ExecutionContext { Workload::Reducer(ctx) => Self::reducer(database_identity, ctx), Workload::Sql => Self::sql(database_identity), Workload::Subscribe => Self::subscribe(database_identity), + Workload::Unsubscribe => Self::unsubscribe(database_identity), Workload::Update => Self::incremental_update(database_identity), } } @@ -175,6 +179,11 @@ impl ExecutionContext { Self::new(database, None, WorkloadType::Subscribe) } + /// Returns an [ExecutionContext] for an unsubscribe call. + pub fn unsubscribe(database: Identity) -> Self { + Self::new(database, None, WorkloadType::Unsubscribe) + } + /// Returns an [ExecutionContext] for a subscription update. 
pub fn incremental_update(database: Identity) -> Self { Self::new(database, None, WorkloadType::Update) diff --git a/crates/core/src/subscription/execution_unit.rs b/crates/core/src/subscription/execution_unit.rs index f44df10503..222ad2834f 100644 --- a/crates/core/src/subscription/execution_unit.rs +++ b/crates/core/src/subscription/execution_unit.rs @@ -13,6 +13,7 @@ use spacetimedb_lib::db::error::AuthError; use spacetimedb_lib::relation::DbTable; use spacetimedb_lib::{Identity, ProductValue}; use spacetimedb_primitives::TableId; +use spacetimedb_sats::u256; use spacetimedb_vm::eval::IterRows; use spacetimedb_vm::expr::{AuthAccess, NoInMemUsed, Query, QueryExpr, SourceExpr, SourceId}; use spacetimedb_vm::rel_ops::RelOps; @@ -41,6 +42,12 @@ pub struct QueryHash { data: [u8; 32], } +impl From for u256 { + fn from(hash: QueryHash) -> Self { + u256::from_le_bytes(hash.data) + } +} + impl QueryHash { pub const NONE: Self = Self { data: [0; 32] }; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 52b5ccbea5..efe2742918 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,9 +1,13 @@ use super::execution_unit::{ExecutionUnit, QueryHash}; use super::module_subscription_manager::SubscriptionManager; -use super::query::compile_read_only_query; +use super::query::{compile_read_only_query, compile_read_only_queryset}; use super::subscription::ExecutionSet; -use crate::client::messages::{SubscriptionUpdateMessage, TransactionUpdateMessage}; +use crate::client::messages::{ + SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, + TransactionUpdateMessage, +}; use crate::client::{ClientActorId, ClientConnectionSender, Protocol}; +use crate::db::datastore::locking_tx_datastore::tx::TxId; use crate::db::datastore::system_tables::StVarTable; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::DBError; @@ -14,7 +18,9 @@ use crate::sql::ast::SchemaViewer; use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use parking_lot::RwLock; -use spacetimedb_client_api_messages::websocket::FormatSwitch; +use spacetimedb_client_api_messages::websocket::{ + BsatnFormat, FormatSwitch, JsonFormat, SubscribeSingle, TableUpdate, Unsubscribe, +}; use spacetimedb_expr::check::compile_sql_sub; use spacetimedb_expr::ty::TyCtx; use spacetimedb_lib::identity::AuthCtx; @@ -46,9 +52,169 @@ impl ModuleSubscriptions { } } + /// Run auth and row limit checks for a new subscriber, then compute the initial query results. 
+ fn evaluate_initial_subscription( + &self, + sender: Arc, + query: Arc, + auth: AuthCtx, + tx: &TxId, + ) -> Result, TableUpdate>, DBError> { + query.check_auth(auth.owner, auth.caller).map_err(ErrorVm::Auth)?; + + check_row_limit( + &query, + &self.relational_db, + tx, + |query, tx| query.row_estimate(tx), + &auth, + )?; + + let slow_query_threshold = StVarTable::sub_limit(&self.relational_db, tx)?.map(Duration::from_millis); + Ok(match sender.config.protocol { + Protocol::Binary => FormatSwitch::Bsatn( + query + .eval( + &self.relational_db, + tx, + &query.sql, + slow_query_threshold, + sender.config.compression, + ) + .unwrap_or(TableUpdate::empty(query.return_table(), query.return_name())), + ), + Protocol::Text => FormatSwitch::Json( + query + .eval( + &self.relational_db, + tx, + &query.sql, + slow_query_threshold, + sender.config.compression, + ) + .unwrap_or(TableUpdate::empty(query.return_table(), query.return_name())), + ), + }) + } + + #[tracing::instrument(skip_all)] + pub fn add_subscription( + &self, + sender: Arc, + request: SubscribeSingle, + timer: Instant, + _assert: Option, + ) -> Result<(), DBError> { + let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Subscribe), |tx| { + self.relational_db.release_tx(tx); + }); + let auth = AuthCtx::new(self.owner_identity, sender.id.identity); + let guard = self.subscriptions.read(); + let query = super::query::WHITESPACE.replace_all(&request.query, " "); + let sql = query.trim(); + let hash = QueryHash::from_string(sql); + let query = if let Some(unit) = guard.query(&hash) { + unit + } else { + // NOTE: The following ensures compliance with the 1.0 sql api. + // Come 1.0, it will have replaced the current compilation stack. + compile_sql_sub( + &mut TyCtx::default(), + sql, + &SchemaViewer::new(&self.relational_db, &*tx, &auth), + )?; + + let compiled = compile_read_only_query(&self.relational_db, &auth, &tx, sql)?; + Arc::new(ExecutionUnit::new(compiled, hash)?) + }; + + drop(guard); + + let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), auth, &tx)?; + + // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. + // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here + // but that should not pose an issue. + let mut subscriptions = self.subscriptions.write(); + subscriptions.add_subscription(sender.clone(), query.clone(), request.query_id)?; + + WORKER_METRICS + .subscription_queries + .with_label_values(&self.relational_db.database_identity()) + .set(subscriptions.num_unique_queries() as i64); + + #[cfg(test)] + if let Some(assert) = _assert { + assert(&tx); + } + + // NOTE: It is important to send the state in this thread because if you spawn a new + // thread it's possible for messages to get sent to the client out of order. 
If you do + // spawn in another thread messages will need to be buffered until the state is sent out + // on the wire + let _ = sender.send_message(SubscriptionMessage { + request_id: Some(request.request_id), + query_id: Some(request.query_id), + timer: Some(timer), + result: SubscriptionResult::Subscribe(SubscriptionRows { + table_id: query.return_table(), + table_name: query.return_name(), + table_rows, + }), + }); + Ok(()) + } + + pub fn remove_subscription( + &self, + sender: Arc, + request: Unsubscribe, + timer: Instant, + ) -> Result<(), DBError> { + let mut subscriptions = self.subscriptions.write(); + let query = match subscriptions.remove_subscription((sender.id.identity, sender.id.address), request.query_id) { + Ok(query) => query, + Err(error) => { + // Apparently we ignore errors sending messages. + let _ = sender.send_message(SubscriptionMessage { + request_id: Some(request.request_id), + query_id: None, + timer: Some(timer), + result: SubscriptionResult::Error(SubscriptionError { + table_id: None, + message: error.to_string().into(), + }), + }); + return Ok(()); + } + }; + let auth = AuthCtx::new(self.owner_identity, sender.id.identity); + let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| { + self.relational_db.release_tx(tx); + }); + let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), auth, &tx)?; + + WORKER_METRICS + .subscription_queries + .with_label_values(&self.relational_db.database_identity()) + .set(subscriptions.num_unique_queries() as i64); + let _ = sender.send_message(SubscriptionMessage { + request_id: Some(request.request_id), + query_id: Some(request.query_id), + timer: Some(timer), + result: SubscriptionResult::Unsubscribe(SubscriptionRows { + table_id: query.return_table(), + table_name: query.return_name(), + table_rows, + }), + }); + Ok(()) + } + /// Add a subscriber to the module. NOTE: this function is blocking. + /// This is used for the legacy subscription API which uses a set of queries. #[tracing::instrument(skip_all)] - pub fn add_subscriber( + pub fn add_legacy_subscriber( &self, sender: Arc, subscription: Subscribe, @@ -94,7 +260,7 @@ impl ModuleSubscriptions { &SchemaViewer::new(&self.relational_db, &*tx, &auth), )?; - let mut compiled = compile_read_only_query(&self.relational_db, &auth, &tx, sql)?; + let mut compiled = compile_read_only_queryset(&self.relational_db, &auth, &tx, sql)?; // Note that no error path is needed here. // We know this vec only has a single element, // since `parse_and_type_sub` guarantees it. @@ -141,9 +307,8 @@ impl ModuleSubscriptions { // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here // but that should not pose an issue. 
let mut subscriptions = self.subscriptions.write(); - subscriptions.remove_subscription(&(sender.id.identity, sender.id.address)); - subscriptions.add_subscription(sender.clone(), execution_set.into_iter()); - let num_queries = subscriptions.num_queries(); + subscriptions.set_legacy_subscription(sender.clone(), execution_set.into_iter()); + let num_queries = subscriptions.num_unique_queries(); WORKER_METRICS .subscription_queries @@ -169,11 +334,11 @@ impl ModuleSubscriptions { pub fn remove_subscriber(&self, client_id: ClientActorId) { let mut subscriptions = self.subscriptions.write(); - subscriptions.remove_subscription(&(client_id.identity, client_id.address)); + subscriptions.remove_all_subscriptions(&(client_id.identity, client_id.address)); WORKER_METRICS .subscription_queries .with_label_values(&self.relational_db.database_identity()) - .set(subscriptions.num_queries() as i64); + .set(subscriptions.num_unique_queries() as i64); } /// Commit a transaction and broadcast its ModuleEvent to all interested subscribers. @@ -261,7 +426,7 @@ mod tests { query_strings: [sql.into()].into(), request_id: 0, }; - module_subscriptions.add_subscriber(sender, subscribe, Instant::now(), assert) + module_subscriptions.add_legacy_subscriber(sender, subscribe, Instant::now(), assert) } /// Asserts that a subscription holds a tx handle for the entire length of its evaluation. diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 3639a06589..c9a8c19c41 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -2,12 +2,14 @@ use super::execution_unit::{ExecutionUnit, QueryHash}; use crate::client::messages::{SubscriptionUpdateMessage, TransactionUpdateMessage}; use crate::client::{ClientConnectionSender, Protocol}; use crate::db::relational_db::{RelationalDB, Tx}; +use crate::error::DBError; use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue}; use crate::messages::websocket::{self as ws, TableUpdate}; use arrayvec::ArrayVec; +use hashbrown::hash_map::OccupiedError; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_client_api_messages::websocket::{ - BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryUpdate, WebsocketFormat, + BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, WebsocketFormat, }; use spacetimedb_data_structures::map::{Entry, HashCollectionExt, HashMap, HashSet, IntMap}; use spacetimedb_lib::{Address, Identity}; @@ -18,11 +20,63 @@ use std::time::Duration; /// Clients are uniquely identified by their Identity and Address. /// Identity is insufficient because different Addresses can use the same Identity. /// TODO: Determine if Address is sufficient for uniquely identifying a client. -type Id = (Identity, Address); +type ClientId = (Identity, Address); type Query = Arc; type Client = Arc; type SwitchedDbUpdate = FormatSwitch, ws::DatabaseUpdate>; +/// ClientQueryId is an identifier for a query set by the client. +type ClientQueryId = QueryId; +/// SubscriptionId is a globally unique identifier for a subscription. +type SubscriptionId = (ClientId, ClientQueryId); + +/// For each client, we hold a handle for sending messages, and we track the queries they are subscribed to. 
+#[derive(Debug)] +struct ClientInfo { + outbound_ref: Client, + subscriptions: HashMap, + // This should be removed when we migrate to SubscribeSingle. + legacy_subscriptions: HashSet, +} + +impl ClientInfo { + fn new(outbound_ref: Client) -> Self { + Self { + outbound_ref, + subscriptions: HashMap::default(), + legacy_subscriptions: HashSet::default(), + } + } +} + +/// For each query that has subscribers, we track a set of legacy subscribers and individual subscriptions. +#[derive(Debug)] +struct QueryState { + query: Query, + legacy_subscribers: HashSet, + subscriptions: HashSet, +} + +impl QueryState { + fn new(query: Query) -> Self { + Self { + query, + legacy_subscribers: HashSet::default(), + subscriptions: HashSet::default(), + } + } + fn has_subscribers(&self) -> bool { + !self.subscriptions.is_empty() || !self.legacy_subscribers.is_empty() + } + + // This returns all of the clients listening to a query. If a client has multiple subscriptions for this query, it will appear twice. + fn all_clients(&self) -> impl Iterator { + let legacy_iter = self.legacy_subscribers.iter(); + let subscriptions_iter = self.subscriptions.iter().map(|(client_id, _)| client_id); + legacy_iter.chain(subscriptions_iter) + } +} + /// Responsible for the efficient evaluation of subscriptions. /// It performs basic multi-query optimization, /// in that if a query has N subscribers, @@ -30,26 +84,26 @@ type SwitchedDbUpdate = FormatSwitch, ws::Databa /// with the results copied to the N receivers. #[derive(Debug, Default)] pub struct SubscriptionManager { - // Subscriber identities and their client connections. - clients: HashMap, + // State for each client. + clients: HashMap, + // Queries for which there is at least one subscriber. - queries: HashMap, - // The subscribers for each query. - subscribers: HashMap>, + queries: HashMap, + // Inverted index from tables to queries that read from them. 
tables: IntMap>, } impl SubscriptionManager { - pub fn client(&self, id: &Id) -> Client { - self.clients[id].clone() + pub fn client(&self, id: &ClientId) -> Client { + self.clients[id].outbound_ref.clone() } pub fn query(&self, hash: &QueryHash) -> Option { - self.queries.get(hash).cloned() + self.queries.get(hash).map(|state| state.query.clone()) } - pub fn num_queries(&self) -> usize { + pub fn num_unique_queries(&self) -> usize { self.queries.len() } @@ -59,8 +113,10 @@ impl SubscriptionManager { } #[cfg(test)] - fn contains_subscription(&self, subscriber: &Id, query: &QueryHash) -> bool { - self.subscribers.get(query).is_some_and(|ids| ids.contains(subscriber)) + fn contains_legacy_subscription(&self, subscriber: &ClientId, query: &QueryHash) -> bool { + self.queries + .get(query) + .is_some_and(|state| state.legacy_subscribers.contains(subscriber)) } #[cfg(test)] @@ -68,19 +124,158 @@ impl SubscriptionManager { self.tables.get(table).is_some_and(|queries| queries.contains(query)) } + fn remove_legacy_subscriptions(&mut self, client: &ClientId) { + if let Some(ci) = self.clients.get_mut(client) { + let mut queries_to_remove = Vec::new(); + for query_hash in ci.legacy_subscriptions.iter() { + let query_state = self.queries.get_mut(query_hash); + if query_state.is_none() { + tracing::warn!("Query state not found for query hash: {:?}", query_hash); + continue; + } + let query_state = query_state.unwrap(); + query_state.legacy_subscribers.remove(client); + if !query_state.has_subscribers() { + SubscriptionManager::remove_table_query( + &mut self.tables, + query_state.query.return_table(), + query_hash, + ); + SubscriptionManager::remove_table_query( + &mut self.tables, + query_state.query.filter_table(), + query_hash, + ); + queries_to_remove.push(*query_hash); + } + } + ci.legacy_subscriptions.clear(); + for query_hash in queries_to_remove { + self.queries.remove(&query_hash); + } + } + } + + pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result { + let subscription_id = (client_id, query_id); + let ci = if let Some(ci) = self.clients.get_mut(&client_id) { + ci + } else { + return Err(anyhow::anyhow!("Client not found: {:?}", client_id).into()); + }; + + let query_hash = if let Some(query_hash) = ci.subscriptions.remove(&subscription_id) { + query_hash + } else { + return Err(anyhow::anyhow!("Subscription not found: {:?}", subscription_id).into()); + }; + let query_state = match self.queries.get_mut(&query_hash) { + Some(query_state) => query_state, + None => return Err(anyhow::anyhow!("Query state not found for query hash: {:?}", query_hash).into()), + }; + let query = query_state.query.clone(); + // Check if the query has any subscribers left. + let should_remove = { + query_state.subscriptions.remove(&subscription_id); + if !query_state.has_subscribers() { + SubscriptionManager::remove_table_query( + &mut self.tables, + query_state.query.return_table(), + &query_hash, + ); + SubscriptionManager::remove_table_query( + &mut self.tables, + query_state.query.filter_table(), + &query_hash, + ); + true + } else { + false + } + }; + if should_remove { + self.queries.remove(&query_hash); + } + Ok(query) + } + + /// Adds a single subscription for a client. 
+ pub fn add_subscription(&mut self, client: Client, query: Query, query_id: ClientQueryId) -> Result<(), DBError> { + let client_id = (client.id.identity, client.id.address); + let ci = self + .clients + .entry(client_id) + .or_insert_with(|| ClientInfo::new(client.clone())); + let subscription_id = (client_id, query_id); + let hash = query.hash(); + + if let Err(OccupiedError { .. }) = ci.subscriptions.try_insert(subscription_id, hash) { + return Err(anyhow::anyhow!( + "Subscription with id {:?} already exists for client: {:?}", + query_id, + client_id + ) + .into()); + } + + let query_state = self + .queries + .entry(hash) + .or_insert_with(|| QueryState::new(query.clone())); + + // If this is new, we need to update the table to query mapping. + if !query_state.has_subscribers() { + self.tables.entry(query.return_table()).or_default().insert(hash); + self.tables.entry(query.filter_table()).or_default().insert(hash); + query_state.subscriptions.insert(subscription_id); + } + + query_state.subscriptions.insert(subscription_id); + + Ok(()) + } + /// Adds a client and its queries to the subscription manager. + /// Sets up the set of subscriptions for the client, replacing any existing legacy subscriptions. + /// /// If a query is not already indexed, /// its table ids added to the inverted index. - #[tracing::instrument(skip_all)] - pub fn add_subscription(&mut self, client: Client, queries: impl IntoIterator) { - let id = (client.id.identity, client.id.address); - self.clients.insert(id, client); + // #[tracing::instrument(skip_all)] + pub fn set_legacy_subscription(&mut self, client: Client, queries: impl IntoIterator) { + // TODO: Remove existing subscriptions. + let client_id = (client.id.identity, client.id.address); + // First, remove any existing legacy subscriptions. + self.remove_legacy_subscriptions(&client_id); + + // Now, add the new subscriptions. + let ci = self + .clients + .entry(client_id) + .or_insert_with(|| ClientInfo::new(client.clone())); for unit in queries { let hash = unit.hash(); + ci.legacy_subscriptions.insert(hash); + let query_state = self + .queries + .entry(hash) + .or_insert_with(|| QueryState::new(unit.clone())); self.tables.entry(unit.return_table()).or_default().insert(hash); self.tables.entry(unit.filter_table()).or_default().insert(hash); - self.subscribers.entry(hash).or_default().insert(id); - self.queries.insert(hash, unit); + query_state.legacy_subscribers.insert(client_id); + // self.subscribers.entry(hash).or_default().insert(id); + // self.queries.insert(hash, unit); + } + } + + // Remove `hash` from the set of queries for `table_id`. + // When the table has no queries, cleanup the map entry altogether. + // This takes a ref to the table map instead of `self` to avoid borrowing issues. + fn remove_table_query(tables: &mut IntMap>, table_id: TableId, hash: &QueryHash) { + if let Entry::Occupied(mut entry) = tables.entry(table_id) { + let hashes = entry.get_mut(); + if hashes.remove(hash) && hashes.is_empty() { + entry.remove(); + } } } @@ -88,29 +283,33 @@ impl SubscriptionManager { /// If a query no longer has any subscribers, /// it is removed from the index along with its table ids. #[tracing::instrument(skip_all)] - pub fn remove_subscription(&mut self, client: &Id) { - // Remove `hash` from the set of queries for `table_id`. - // When the table has no queries, cleanup the map entry altogether. 
- let mut remove_table_query = |table_id: TableId, hash: &QueryHash| { - if let Entry::Occupied(mut entry) = self.tables.entry(table_id) { - let hashes = entry.get_mut(); - if hashes.remove(hash) && hashes.is_empty() { - entry.remove(); - } + pub fn remove_all_subscriptions(&mut self, client: &ClientId) { + self.remove_legacy_subscriptions(client); + let client_info = self.clients.get(client); + if client_info.is_none() { + return; + } + let client_info = client_info.unwrap(); + debug_assert!(client_info.legacy_subscriptions.is_empty()); + let mut queries_to_remove = Vec::new(); + client_info.subscriptions.iter().for_each(|(sub_id, query_hash)| { + let query_state = self.queries.get_mut(query_hash); + if query_state.is_none() { + tracing::warn!("Query state not found for query hash: {:?}", query_hash); + return; } - }; - - self.clients.remove(client); - self.subscribers.retain(|hash, ids| { - ids.remove(client); - if ids.is_empty() { - if let Some(query) = self.queries.remove(hash) { - remove_table_query(query.return_table(), hash); - remove_table_query(query.filter_table(), hash); - } + let query_state = query_state.unwrap(); + query_state.subscriptions.remove(sub_id); + // This could happen twice for the same hash if a client has a duplicate, but that's fine. It is idempotent. + if !query_state.has_subscribers() { + queries_to_remove.push(*query_hash); + SubscriptionManager::remove_table_query(&mut self.tables, query_state.query.return_table(), query_hash); + SubscriptionManager::remove_table_query(&mut self.tables, query_state.query.filter_table(), query_hash); } - !ids.is_empty() }); + for query_hash in queries_to_remove { + self.queries.remove(&query_hash); + } } /// This method takes a set of delta tables, @@ -148,7 +347,7 @@ impl SubscriptionManager { let mut eval = units .par_iter() .filter_map(|(&hash, tables)| { - let unit = self.queries.get(hash)?; + let unit = &self.queries.get(hash)?.query; unit.eval_incr(db, tx, &unit.sql, tables.iter().copied(), slow_query_threshold) .map(|table| (hash, table)) }) @@ -177,14 +376,22 @@ impl SubscriptionManager { .clone() } - self.subscribers.get(hash).into_iter().flatten().map(move |id| { - let client = &*self.clients[id]; - let update = match client.config.protocol { - Protocol::Binary => Bsatn(memo_encode::(&delta.updates, client, &mut ops_bin)), - Protocol::Text => Json(memo_encode::(&delta.updates, client, &mut ops_json)), - }; - (id, table_id, table_name.clone(), update) - }) + self.queries + .get(hash) + .into_iter() + .flat_map(|query| query.all_clients()) + .map(move |id| { + let client = &self.clients[id].outbound_ref; + let update = match client.config.protocol { + Protocol::Binary => { + Bsatn(memo_encode::(&delta.updates, client, &mut ops_bin)) + } + Protocol::Text => { + Json(memo_encode::(&delta.updates, client, &mut ops_json)) + } + }; + (id, table_id, table_name.clone(), update) + }) }) .collect::>() .into_iter() // For each subscriber, aggregate all the updates for the same table. // That is, we build a map `(subscriber_id, table_id) -> updates`. // A particular subscriber uses only one format, // so their `TableUpdate` will contain either JSON (`Protocol::Text`) // or BSATN (`Protocol::Binary`). .fold( - HashMap::<(&Id, TableId), FormatSwitch, TableUpdate<_>>>::new(), + HashMap::<(&ClientId, TableId), FormatSwitch, TableUpdate<_>>>::new(), |mut tables, (id, table_id, table_name, update)| { match tables.entry((id, table_id)) { Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(update) { @@ -214,7 +421,7 @@ impl SubscriptionManager { // So before sending the updates to each client, // we must stitch together the `TableUpdate*`s into an aggregated list. 
.fold( - HashMap::<&Id, SwitchedDbUpdate>::new(), + HashMap::<&ClientId, SwitchedDbUpdate>::new(), |mut updates, ((id, _), update)| { let entry = updates.entry(id); let entry = entry.or_insert_with(|| match &update { @@ -275,12 +482,14 @@ mod tests { use std::{sync::Arc, time::Duration}; use spacetimedb_client_api_messages::timestamp::Timestamp; + use spacetimedb_client_api_messages::websocket::QueryId; use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, Address, AlgebraicType, Identity}; use spacetimedb_primitives::TableId; use spacetimedb_vm::expr::CrudExpr; use super::SubscriptionManager; use crate::execution_context::Workload; + use crate::subscription::module_subscription_manager::ClientQueryId; use crate::{ client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName}, db::relational_db::{tests_utils::TestDB, RelationalDB}, @@ -331,7 +540,7 @@ mod tests { } #[test] - fn test_subscribe() -> ResultTest<()> { + fn test_subscribe_legacy() -> ResultTest<()> { let db = TestDB::durable()?; let table_id = create_table(&db, "T")?; @@ -343,15 +552,122 @@ mod tests { let client = Arc::new(client(0)); let mut subscriptions = SubscriptionManager::default(); - subscriptions.add_subscription(client, [plan]); + subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]); assert!(subscriptions.contains_query(&hash)); - assert!(subscriptions.contains_subscription(&id, &hash)); + assert!(subscriptions.contains_legacy_subscription(&id, &hash)); assert!(subscriptions.query_reads_from_table(&hash, &table_id)); Ok(()) } + #[test] + fn test_subscribe_single_adds_table_mapping() -> ResultTest<()> { + let db = TestDB::durable()?; + + let table_id = create_table(&db, "T")?; + let sql = "select * from T"; + let plan = compile_plan(&db, sql)?; + let hash = plan.hash(); + + let client = Arc::new(client(0)); + + let query_id: ClientQueryId = QueryId::new(1); + let mut subscriptions = SubscriptionManager::default(); + subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?; + assert!(subscriptions.query_reads_from_table(&hash, &table_id)); + + Ok(()) + } + + #[test] + fn test_unsubscribe_from_the_only_subscription() -> ResultTest<()> { + let db = TestDB::durable()?; + + let table_id = create_table(&db, "T")?; + let sql = "select * from T"; + let plan = compile_plan(&db, sql)?; + let hash = plan.hash(); + + let client = Arc::new(client(0)); + + let query_id: ClientQueryId = QueryId::new(1); + let mut subscriptions = SubscriptionManager::default(); + subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?; + assert!(subscriptions.query_reads_from_table(&hash, &table_id)); + + let client_id = (client.id.identity, client.id.address); + subscriptions.remove_subscription(client_id, query_id)?; + assert!(!subscriptions.query_reads_from_table(&hash, &table_id)); + + Ok(()) + } + + #[test] + fn test_unsubscribe_with_unknown_query_id_fails() -> ResultTest<()> { + let db = TestDB::durable()?; + + create_table(&db, "T")?; + let sql = "select * from T"; + let plan = compile_plan(&db, sql)?; + + let client = Arc::new(client(0)); + + let query_id: ClientQueryId = QueryId::new(1); + let mut subscriptions = SubscriptionManager::default(); + subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?; + + let client_id = (client.id.identity, client.id.address); + assert!(subscriptions.remove_subscription(client_id, QueryId::new(2)).is_err()); + + Ok(()) + } + + #[test] + fn test_subscribe_and_unsubscribe_with_duplicate_queries() -> ResultTest<()> { + let db = 
TestDB::durable()?; + + let table_id = create_table(&db, "T")?; + let sql = "select * from T"; + let plan = compile_plan(&db, sql)?; + let hash = plan.hash(); + + let client = Arc::new(client(0)); + + let query_id: ClientQueryId = QueryId::new(1); + let mut subscriptions = SubscriptionManager::default(); + subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?; + subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(2))?; + + let client_id = (client.id.identity, client.id.address); + subscriptions.remove_subscription(client_id, query_id)?; + + assert!(subscriptions.query_reads_from_table(&hash, &table_id)); + + Ok(()) + } + + #[test] + fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> { + let db = TestDB::durable()?; + + create_table(&db, "T")?; + let sql = "select * from T"; + let plan = compile_plan(&db, sql)?; + + let client = Arc::new(client(0)); + + let query_id: ClientQueryId = QueryId::new(1); + let mut subscriptions = SubscriptionManager::default(); + subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?; + + assert!(subscriptions + .add_subscription(client.clone(), plan.clone(), query_id) + .is_err()); + + Ok(()) + } + #[test] fn test_unsubscribe() -> ResultTest<()> { let db = TestDB::durable()?; @@ -365,11 +681,11 @@ mod tests { let client = Arc::new(client(0)); let mut subscriptions = SubscriptionManager::default(); - subscriptions.add_subscription(client, [plan]); - subscriptions.remove_subscription(&id); + subscriptions.set_legacy_subscription(client, [plan]); + subscriptions.remove_all_subscriptions(&id); assert!(!subscriptions.contains_query(&hash)); - assert!(!subscriptions.contains_subscription(&id, &hash)); + assert!(!subscriptions.contains_legacy_subscription(&id, &hash)); assert!(!subscriptions.query_reads_from_table(&hash, &table_id)); Ok(()) @@ -388,17 +704,17 @@ mod tests { let client = Arc::new(client(0)); let mut subscriptions = SubscriptionManager::default(); - subscriptions.add_subscription(client.clone(), [plan.clone()]); - subscriptions.add_subscription(client.clone(), [plan.clone()]); + subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]); + subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]); assert!(subscriptions.contains_query(&hash)); - assert!(subscriptions.contains_subscription(&id, &hash)); + assert!(subscriptions.contains_legacy_subscription(&id, &hash)); assert!(subscriptions.query_reads_from_table(&hash, &table_id)); - subscriptions.remove_subscription(&id); + subscriptions.remove_all_subscriptions(&id); assert!(!subscriptions.contains_query(&hash)); - assert!(!subscriptions.contains_subscription(&id, &hash)); + assert!(!subscriptions.contains_legacy_subscription(&id, &hash)); assert!(!subscriptions.query_reads_from_table(&hash, &table_id)); Ok(()) @@ -420,21 +736,21 @@ mod tests { let client1 = Arc::new(client(1)); let mut subscriptions = SubscriptionManager::default(); - subscriptions.add_subscription(client0, [plan.clone()]); - subscriptions.add_subscription(client1, [plan.clone()]); + subscriptions.set_legacy_subscription(client0, [plan.clone()]); + subscriptions.set_legacy_subscription(client1, [plan.clone()]); assert!(subscriptions.contains_query(&hash)); - assert!(subscriptions.contains_subscription(&id0, &hash)); - assert!(subscriptions.contains_subscription(&id1, &hash)); + assert!(subscriptions.contains_legacy_subscription(&id0, &hash)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash)); 
assert!(subscriptions.query_reads_from_table(&hash, &table_id)); - subscriptions.remove_subscription(&id0); + subscriptions.remove_all_subscriptions(&id0); assert!(subscriptions.contains_query(&hash)); - assert!(subscriptions.contains_subscription(&id1, &hash)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash)); assert!(subscriptions.query_reads_from_table(&hash, &table_id)); - assert!(!subscriptions.contains_subscription(&id0, &hash)); + assert!(!subscriptions.contains_legacy_subscription(&id0, &hash)); Ok(()) } @@ -465,18 +781,18 @@ mod tests { let client1 = Arc::new(client(1)); let mut subscriptions = SubscriptionManager::default(); - subscriptions.add_subscription(client0, [plan_scan.clone(), plan_select0.clone()]); - subscriptions.add_subscription(client1, [plan_scan.clone(), plan_select1.clone()]); + subscriptions.set_legacy_subscription(client0, [plan_scan.clone(), plan_select0.clone()]); + subscriptions.set_legacy_subscription(client1, [plan_scan.clone(), plan_select1.clone()]); assert!(subscriptions.contains_query(&hash_scan)); assert!(subscriptions.contains_query(&hash_select0)); assert!(subscriptions.contains_query(&hash_select1)); - assert!(subscriptions.contains_subscription(&id0, &hash_scan)); - assert!(subscriptions.contains_subscription(&id0, &hash_select0)); + assert!(subscriptions.contains_legacy_subscription(&id0, &hash_scan)); + assert!(subscriptions.contains_legacy_subscription(&id0, &hash_select0)); - assert!(subscriptions.contains_subscription(&id1, &hash_scan)); - assert!(subscriptions.contains_subscription(&id1, &hash_select1)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1)); assert!(subscriptions.query_reads_from_table(&hash_scan, &t)); assert!(subscriptions.query_reads_from_table(&hash_select0, &t)); @@ -486,17 +802,17 @@ mod tests { assert!(!subscriptions.query_reads_from_table(&hash_select0, &s)); assert!(!subscriptions.query_reads_from_table(&hash_select1, &t)); - subscriptions.remove_subscription(&id0); + subscriptions.remove_all_subscriptions(&id0); assert!(subscriptions.contains_query(&hash_scan)); assert!(subscriptions.contains_query(&hash_select1)); assert!(!subscriptions.contains_query(&hash_select0)); - assert!(subscriptions.contains_subscription(&id1, &hash_scan)); - assert!(subscriptions.contains_subscription(&id1, &hash_select1)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan)); + assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1)); - assert!(!subscriptions.contains_subscription(&id0, &hash_scan)); - assert!(!subscriptions.contains_subscription(&id0, &hash_select0)); + assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_scan)); + assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_select0)); assert!(subscriptions.query_reads_from_table(&hash_scan, &t)); assert!(subscriptions.query_reads_from_table(&hash_select1, &s)); diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index db75d33722..69cdc3b718 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -10,6 +10,7 @@ use spacetimedb_vm::expr::{self, Crud, CrudExpr, QueryExpr}; pub(crate) static WHITESPACE: Lazy = Lazy::new(|| Regex::new(r"\s+").unwrap()); pub const SUBSCRIBE_TO_ALL_QUERY: &str = "SELECT * FROM *"; +// TODO: Remove this after the SubscribeSingle migration. 
// TODO: It's semantically wrong to `SUBSCRIBE_TO_ALL_QUERY` // as it can only return back the changes valid for the tables in scope *right now* // instead of **continuously updating** the db changes @@ -20,7 +21,7 @@ pub const SUBSCRIBE_TO_ALL_QUERY: &str = "SELECT * FROM *"; /// /// This is necessary when merging multiple SQL queries into a single query set, /// as in [`crate::subscription::module_subscription_actor::ModuleSubscriptions::add_subscriber`]. -pub fn compile_read_only_query( +pub fn compile_read_only_queryset( relational_db: &RelationalDB, auth: &AuthCtx, tx: &Tx, @@ -61,6 +62,45 @@ pub fn compile_read_only_query( } } +/// Compile a string into a single read-only query. +/// This returns an error if the string has multiple queries or mutations. +pub fn compile_read_only_query( + relational_db: &RelationalDB, + auth: &AuthCtx, + tx: &Tx, + input: &str, +) -> Result { + let input = input.trim(); + if input.is_empty() { + return Err(SubscriptionError::Empty.into()); + } + + // Remove redundant whitespace, and in particular newlines, for debug info. + let input = WHITESPACE.replace_all(input, " "); + + let single: CrudExpr = { + let mut compiled = compile_sql(relational_db, auth, tx, &input)?; + // Return an error if this doesn't produce exactly one query. + let first_query = compiled.pop(); + let other_queries = compiled.len(); + match (first_query, other_queries) { + (None, _) => return Err(SubscriptionError::Empty.into()), + (Some(q), 0) => q, + _ => return Err(SubscriptionError::Multiple.into()), + } + }; + + Err(SubscriptionError::SideEffect(match single { + CrudExpr::Query(query) => return SupportedQuery::new(query, input.to_string()), + CrudExpr::Insert { .. } => Crud::Insert, + CrudExpr::Update { .. } => Crud::Update, + CrudExpr::Delete { .. } => Crud::Delete, + CrudExpr::SetVar { .. } => Crud::Config, + CrudExpr::ReadVar { .. } => Crud::Config, + }) + .into()) +} + /// The kind of [`QueryExpr`] currently supported for incremental evaluation. #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)] pub enum Supported { @@ -506,7 +546,7 @@ mod tests { AND MobileEntityState.location_z < 192000"; let tx = db.begin_tx(Workload::ForTests); - let qset = compile_read_only_query(&db, &AuthCtx::for_testing(), &tx, sql_query)?; + let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, sql_query)?; for q in qset { let result = run_query( @@ -592,7 +632,7 @@ mod tests { "SELECT * FROM lhs WHERE id > 5", ]; for scan in scans { - let expr = compile_read_only_query(&db, &AuthCtx::for_testing(), &tx, scan)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, scan)? .pop() .unwrap(); assert_eq!(expr.kind(), Supported::Select, "{scan}\n{expr:#?}"); @@ -601,7 +641,7 @@ mod tests { // Only index semijoins are supported let joins = ["SELECT lhs.* FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE rhs.y < 10"]; for join in joins { - let expr = compile_read_only_query(&db, &AuthCtx::for_testing(), &tx, join)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join)? 
.pop() .unwrap(); assert_eq!(expr.kind(), Supported::Semijoin, "{join}\n{expr:#?}"); @@ -614,7 +654,7 @@ mod tests { "SELECT * FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE lhs.x < 10", ]; for join in joins { - match compile_read_only_query(&db, &AuthCtx::for_testing(), &tx, join) { + match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join) { Err(DBError::Subscription(SubscriptionError::Unsupported(_)) | DBError::TypeError(_)) => (), x => panic!("Unexpected: {x:?}"), } diff --git a/crates/sdk/src/db_connection.rs b/crates/sdk/src/db_connection.rs index 44d73b69c3..bb455d644c 100644 --- a/crates/sdk/src/db_connection.rs +++ b/crates/sdk/src/db_connection.rs @@ -983,6 +983,9 @@ async fn parse_loop( ws::ServerMessage::OneOffQueryResponse(_) => { unreachable!("The Rust SDK does not implement one-off queries") } + ws::ServerMessage::SubscribeApplied(_) => todo!(), + ws::ServerMessage::UnsubscribeApplied(_) => todo!(), + ws::ServerMessage::SubscriptionError(_) => todo!(), }) .expect("Failed to send ParsedMessage to main thread"); } diff --git a/smoketests/config.toml b/smoketests/config.toml index e61cf87d07..cd91a9c8a3 100644 --- a/smoketests/config.toml +++ b/smoketests/config.toml @@ -1,5 +1,5 @@ default_server = "127.0.0.1:3000" -spacetimedb_token = "" +spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwODIzOTA3M2M5MDgyNWQxZWY0MWVjNGJlYzg1MmNkNWIzOTdiMzBjZTVhZjUzNmZlZGExOTE3OWM5ZTJjIiwic3ViIjoiMzVmMjQ5ZTYtZjI0NC00ZDE1LWIzYmUtM2Q5NWZjMjA4MTFmIiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTczMjYzODI4NywiZXhwIjpudWxsfQ.oVIaYaH7w8ZiuowAflzKo4BrUeGk_1WqlaySMCYqIrkzB96SxVjCQuR0PYM8dOs7WhsiXvYH7dgVxbSbVV4PGg" [[server_configs]] nickname = "localhost"
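As a recap of the new per-query bookkeeping in `SubscriptionManager`, the sketch below condenses the unit tests added in this patch into a single flow. It assumes it lives in the same test module as those tests, so the `client`, `create_table`, and `compile_plan` helpers and the `#[cfg(test)]` accessors are in scope; everything else is the API added above.

```rust
#[test]
fn subscription_lifecycle_recap() -> ResultTest<()> {
    let db = TestDB::durable()?;
    let table_id = create_table(&db, "T")?;
    let plan = compile_plan(&db, "select * from T")?;
    let hash = plan.hash();

    let client = Arc::new(client(0));
    let client_id = (client.id.identity, client.id.address);

    let mut subscriptions = SubscriptionManager::default();

    // The first subscription registers the query and indexes its table.
    subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(1))?;
    assert!(subscriptions.query_reads_from_table(&hash, &table_id));

    // Reusing a QueryId on the same connection is rejected.
    assert!(subscriptions
        .add_subscription(client.clone(), plan.clone(), QueryId::new(1))
        .is_err());

    // Dropping the only subscription removes the query from the table index.
    subscriptions.remove_subscription(client_id, QueryId::new(1))?;
    assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

    Ok(())
}
```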