diff --git a/Cargo.lock b/Cargo.lock index c1974b67..4f79a443 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2415,7 +2415,7 @@ dependencies = [ [[package]] name = "libsqlite3-sys" version = "0.26.0" -source = "git+https://github.com/psarna/rusqlite?rev=2e28bca6#2e28bca6e8ebe8e9496b6bccb7e983ccf3dfad1f" +source = "git+https://github.com/tursodatabase/rusqlite.git?rev=a72d529#a72d529a96d5dc3f4c3181358d8bd5d3a9ead8ac" dependencies = [ "bindgen 0.65.1", "cc", @@ -3246,7 +3246,7 @@ dependencies = [ [[package]] name = "rusqlite" version = "0.29.0" -source = "git+https://github.com/psarna/rusqlite?rev=2e28bca6#2e28bca6e8ebe8e9496b6bccb7e983ccf3dfad1f" +source = "git+https://github.com/tursodatabase/rusqlite.git?rev=a72d529#a72d529a96d5dc3f4c3181358d8bd5d3a9ead8ac" dependencies = [ "bitflags 2.4.0", "fallible-iterator 0.2.0", diff --git a/Cargo.toml b/Cargo.toml index 016673d9..37b0255c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,12 @@ members = [ ] [workspace.dependencies] -rusqlite = { version = "0.29.0", git = "https://github.com/psarna/rusqlite", rev = "2e28bca6", default-features = false, features = [ +rusqlite = { version = "0.29.0", git = "https://github.com/tursodatabase/rusqlite.git", rev = "a72d529", default-features = false, features = [ "buildtime_bindgen", "bundled-libsql-wasm-experimental", "column_decltype", - "load_extension" + "load_extension", + "modern_sqlite" ] } # Config for 'cargo dist' diff --git a/sqld-libsql-bindings/src/lib.rs b/sqld-libsql-bindings/src/lib.rs index 28fefbf7..088970e3 100644 --- a/sqld-libsql-bindings/src/lib.rs +++ b/sqld-libsql-bindings/src/lib.rs @@ -3,10 +3,12 @@ pub mod ffi; pub mod wal_hook; -use std::{ffi::CString, marker::PhantomData, ops::Deref, time::Duration}; +use std::{ffi::CString, ops::Deref, time::Duration}; pub use crate::wal_hook::WalMethodsHook; pub use once_cell::sync::Lazy; +use rusqlite::ffi::sqlite3; +use wal_hook::TransparentMethods; use self::{ ffi::{libsql_wal_methods, libsql_wal_methods_find}, @@ 
-22,12 +24,14 @@ pub fn get_orig_wal_methods() -> anyhow::Result<*mut libsql_wal_methods> { Ok(orig) } -pub struct Connection<'a> { +pub struct Connection { conn: rusqlite::Connection, - _pth: PhantomData<&'a mut ()>, + // Safety: _ctx MUST be dropped after the connection, because the connection has a pointer + // This pointer MUST NOT move out of the connection + _ctx: Box, } -impl Deref for Connection<'_> { +impl Deref for Connection { type Target = rusqlite::Connection; fn deref(&self) -> &Self::Target { @@ -35,27 +39,30 @@ impl Deref for Connection<'_> { } } -impl<'a> Connection<'a> { +impl Connection { /// returns a dummy, in-memory connection. For testing purposes only - pub fn test(_: &mut ()) -> Self { + pub fn test() -> Self { let conn = rusqlite::Connection::open_in_memory().unwrap(); Self { conn, - _pth: PhantomData, + _ctx: Box::new(()), } } +} +impl Connection { /// Opens a database with the regular wal methods in the directory pointed to by path - pub fn open( + pub fn open( path: impl AsRef, flags: rusqlite::OpenFlags, // we technically _only_ need to know about W, but requiring a static ref to the wal_hook ensures that // it has been instanciated and lives for long enough _wal_hook: &'static WalMethodsHook, - hook_ctx: &'a mut W::Context, + hook_ctx: W::Context, auto_checkpoint: u32, ) -> Result { let path = path.as_ref().join("data"); + let mut _ctx = Box::new(hook_ctx); tracing::trace!( "Opening a connection with regular WAL at {}", path.display() @@ -75,7 +82,7 @@ impl<'a> Connection<'a> { flags.bits(), std::ptr::null_mut(), W::name().as_ptr(), - hook_ctx as *mut _ as *mut _, + _ctx.as_mut() as *mut _ as *mut _, ); if rc == 0 { @@ -96,9 +103,14 @@ impl<'a> Connection<'a> { let conn = unsafe { rusqlite::Connection::from_handle_owned(db)? 
}; conn.busy_timeout(Duration::from_millis(5000))?; - Ok(Connection { - conn, - _pth: PhantomData, - }) + Ok(Connection { conn, _ctx }) + } + + /// Returns the raw sqlite handle + /// + /// # Safety + /// The caller is responsible for the returned pointer. + pub unsafe fn handle(&mut self) -> *mut sqlite3 { + self.conn.handle() } } diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index eccfe57a..29abf989 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -1,12 +1,13 @@ +use std::ffi::{c_int, c_void}; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use crossbeam::channel::RecvTimeoutError; -use rusqlite::{ErrorCode, OpenFlags, StatementStatus}; -use sqld_libsql_bindings::wal_hook::WalMethodsHook; -use tokio::sync::{oneshot, watch}; +use parking_lot::{Mutex, RwLock}; +use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus}; +use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; +use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; -use tracing::warn; use crate::auth::{Authenticated, Authorized, Permission}; use crate::error::Error; @@ -22,10 +23,7 @@ use super::config::DatabaseConfigStore; use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse, DescribeResult}; use super::{MakeConnection, Program, Step, TXN_TIMEOUT}; -/// Internal message used to communicate between the database thread and the `LibSqlDb` handle. 
-type ExecCallback = Box) -> anyhow::Result<()> + Send + 'static>; - -pub struct LibSqlDbFactory { +pub struct MakeLibSqlConn { db_path: PathBuf, hook: &'static WalMethodsHook, ctx_builder: Box W::Context + Sync + Send + 'static>, @@ -36,12 +34,13 @@ pub struct LibSqlDbFactory { max_total_response_size: u64, auto_checkpoint: u32, current_frame_no_receiver: watch::Receiver>, + state: Arc>, /// In wal mode, closing the last database takes time, and causes other databases creation to /// return sqlite busy. To mitigate that, we hold on to one connection - _db: Option, + _db: Option>, } -impl LibSqlDbFactory +impl MakeLibSqlConn where W: WalHook + 'static + Sync + Send, W::Context: Send + 'static, @@ -74,6 +73,7 @@ where auto_checkpoint, current_frame_no_receiver, _db: None, + state: Default::default(), }; let db = this.try_create_db().await?; @@ -83,11 +83,11 @@ where } /// Tries to create a database, retrying if the database is busy. - async fn try_create_db(&self) -> Result { + async fn try_create_db(&self) -> Result> { // try 100 times to acquire initial db connection. 
let mut retries = 0; loop { - match self.create_database().await { + match self.make_connection().await { Ok(conn) => return Ok(conn), Err( err @ Error::RusqliteError(rusqlite::Error::SqliteFailure( @@ -111,7 +111,7 @@ where } } - async fn create_database(&self) -> Result { + async fn make_connection(&self) -> Result> { LibSqlConnection::new( self.db_path.clone(), self.extensions.clone(), @@ -125,36 +125,37 @@ where auto_checkpoint: self.auto_checkpoint, }, self.current_frame_no_receiver.clone(), + self.state.clone(), ) .await } } #[async_trait::async_trait] -impl MakeConnection for LibSqlDbFactory +impl MakeConnection for MakeLibSqlConn where W: WalHook + 'static + Sync + Send, W::Context: Send + 'static, { - type Connection = LibSqlConnection; + type Connection = LibSqlConnection; async fn create(&self) -> Result { - self.create_database().await + self.make_connection().await } } -#[derive(Clone, Debug)] -pub struct LibSqlConnection { - sender: crossbeam::channel::Sender, +#[derive(Clone)] +pub struct LibSqlConnection { + inner: Arc>>, } -pub fn open_db<'a, W>( +pub fn open_conn( path: &Path, wal_methods: &'static WalMethodsHook, - hook_ctx: &'a mut W::Context, + hook_ctx: W::Context, flags: Option, auto_checkpoint: u32, -) -> Result, rusqlite::Error> +) -> Result, rusqlite::Error> where W: WalHook, { @@ -167,8 +168,12 @@ where sqld_libsql_bindings::Connection::open(path, flags, wal_methods, hook_ctx, auto_checkpoint) } -impl LibSqlConnection { - pub async fn new( +impl LibSqlConnection +where + W: WalHook, + W::Context: Send, +{ + pub async fn new( path: impl AsRef + Send + 'static, extensions: Arc<[PathBuf]>, wal_hook: &'static WalMethodsHook, @@ -177,109 +182,170 @@ impl LibSqlConnection { config_store: Arc, builder_config: QueryBuilderConfig, current_frame_no_receiver: watch::Receiver>, - ) -> crate::Result - where - W: WalHook, - W::Context: Send, - { - let (sender, receiver) = crossbeam::channel::unbounded::(); - let (init_sender, init_receiver) = 
oneshot::channel(); - - crate::BLOCKING_RT.spawn_blocking(move || { - let mut ctx = hook_ctx; - let mut connection = match Connection::new( + state: Arc>, + ) -> crate::Result { + let conn = tokio::task::spawn_blocking(move || { + Connection::new( path.as_ref(), extensions, wal_hook, - &mut ctx, + hook_ctx, stats, config_store, builder_config, current_frame_no_receiver, - ) { - Ok(conn) => { - let Ok(_) = init_sender.send(Ok(())) else { return }; - conn - } - Err(e) => { - let _ = init_sender.send(Err(e)); - return; - } - }; - - loop { - let exec = match connection.timeout_deadline { - Some(deadline) => match receiver.recv_deadline(deadline.into()) { - Ok(msg) => msg, - Err(RecvTimeoutError::Timeout) => { - warn!("transaction timed out"); - connection.rollback(); - connection.timed_out = true; - connection.timeout_deadline = None; - continue; - } - Err(RecvTimeoutError::Disconnected) => break, - }, - None => match receiver.recv() { - Ok(msg) => msg, - Err(_) => break, - }, - }; - - let maybe_conn = if !connection.timed_out { - Ok(&mut connection) - } else { - Err(Error::LibSqlTxTimeout) - }; - - if exec(maybe_conn).is_err() { - tracing::warn!("Database connection closed unexpectedly"); - return; - }; - } - }); - - init_receiver.await??; + state, + ) + }) + .await + .unwrap()?; - Ok(Self { sender }) + Ok(Self { + inner: Arc::new(Mutex::new(conn)), + }) } } -struct Connection<'a> { - timeout_deadline: Option, - conn: sqld_libsql_bindings::Connection<'a>, - timed_out: bool, +struct Connection { + conn: sqld_libsql_bindings::Connection, stats: Arc, config_store: Arc, builder_config: QueryBuilderConfig, current_frame_no_receiver: watch::Receiver>, + // must be dropped after the connection because the connection refers to it + state: Arc>, + // current txn slot if any + slot: Option>>, +} + +/// A slot for holding the state of a transaction lock permit +struct TxnSlot { + /// Pointer to the connection holding the lock. 
Used to rollback the transaction when the lock + /// is stolen. + conn: Arc>>, + /// Time at which the transaction can be stolen + timeout_at: tokio::time::Instant, + /// Whether the transaction lock was stolen; also set by the holder itself when it detects its own timeout + is_stolen: AtomicBool, +} + +/// The transaction state shared among all connections to the same database +pub struct TxnState { + /// Slot for the connection currently holding the transaction lock + slot: RwLock>>>, + /// Notifier used to wake waiting busy handlers when the lock is released + notify: Notify, +} + +impl Default for TxnState { + fn default() -> Self { + Self { + slot: Default::default(), + notify: Default::default(), + } + } } -impl<'a> Connection<'a> { - fn new( +/// The lock-stealing busy handler. +/// Here is a detailed description of the algorithm: +/// - all connections to a database share a `TxnState`, which contains a `TxnSlot` +/// - when a connection acquires a write lock on the database, this is detected by monitoring the state of the +/// connection before and after the call thanks to [sqlite3_txn_state()](https://www.sqlite.org/c3ref/c_txn_none.html) +/// - if the connection acquired a write lock (txn state none/read -> write), a new txn slot is created. A clone of the +/// `TxnSlot` is placed in the `TxnState` shared with other connections to this database, while another clone is kept by +/// the connection holding the lock. The `TxnSlot` contains: the instant at which the txn should time out, an `is_stolen` flag, and a +/// pointer to the connection currently holding the lock. +/// - when another connection attempts to acquire the lock, the `busy_handler` callback will be called. The callback is +/// passed the `TxnState` for the connection. The handler looks at the current slot to determine when the current txn will +/// time out, and waits for that instant before retrying. The waiting handler can also be notified that the transaction has +/// finished early.
+/// - If the handler waits until the txn timeout and isn't notified of the termination of the txn, it will attempt to steal the lock. +/// This is done by calling rollback on the slot's txn, and marking the slot as stolen. +/// - When a connection notices that its slot has been stolen, it returns a timeout error to the next request. +unsafe extern "C" fn busy_handler(state: *mut c_void, _retries: c_int) -> c_int { + let state = &*(state as *mut TxnState); + let lock = state.slot.read(); + // we take a reference to the slot we will attempt to steal. this is to make sure that we + // actually steal the correct lock. + let slot = match &*lock { + Some(slot) => slot.clone(), + // fast path: there is no slot, try to acquire the lock again + None => return 1, + }; + + tokio::runtime::Handle::current().block_on(async move { + let timeout = { + let slot = lock.as_ref().unwrap(); + let timeout_at = slot.timeout_at; + drop(lock); + tokio::time::sleep_until(timeout_at) + }; + + tokio::select! { + // The connection has notified us that its txn has terminated, try to acquire again + _ = state.notify.notified() => 1, + // the current holder of the transaction has timed out, we will attempt to steal their + // lock. + _ = timeout => { + // only a single connection gets to steal the lock, others retry + if let Some(mut lock) = state.slot.try_write() { + // We check that slot wasn't already stolen, and that there is still a slot. + // The ordering is relaxed because the atomic is only set under the slot lock. + if slot.is_stolen.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed).is_ok() { + // The connection holding the current txn will set itself as stolen when it + // detects a timeout, so if we arrive at this point, then there is + // necessarily a slot, and this slot has to be the one we attempted to + // steal. NOTE(review): a handler that slept on an older, already released slot may also get here; confirm.
+ assert!(lock.take().is_some()); + + let conn = slot.conn.lock(); + // we have a lock on the connection, we don't need more than a + // Relaxed store. + conn.rollback(); + + tracing::info!("stole transaction lock"); + } + } + 1 + } + } + }) +} + +impl Connection { + fn new( path: &Path, extensions: Arc<[PathBuf]>, wal_methods: &'static WalMethodsHook, - hook_ctx: &'a mut W::Context, + hook_ctx: W::Context, stats: Arc, config_store: Arc, builder_config: QueryBuilderConfig, current_frame_no_receiver: watch::Receiver>, + state: Arc>, ) -> Result { + let mut conn = open_conn( + path, + wal_methods, + hook_ctx, + None, + builder_config.auto_checkpoint, + )?; + + // register the lock-stealing busy handler; the raw `state` pointer stays valid because `state` is stored in `Self` below + unsafe { + let ptr = Arc::as_ptr(&state) as *mut _; + rusqlite::ffi::sqlite3_busy_handler(conn.handle(), Some(busy_handler::), ptr); + } + let this = Self { - conn: open_db( - path, - wal_methods, - hook_ctx, - None, - builder_config.auto_checkpoint, - )?, - timeout_deadline: None, - timed_out: false, + conn, stats, config_store, builder_config, current_frame_no_receiver, + state, + slot: None, }; for ext in extensions.iter() { @@ -296,25 +362,88 @@ impl<'a> Connection<'a> { Ok(this) } - fn run(&mut self, pgm: Program, mut builder: B) -> Result { - let mut results = Vec::with_capacity(pgm.steps.len()); + fn run( + this: Arc>, + pgm: Program, + mut builder: B, + ) -> Result<(B, State)> { + use rusqlite::TransactionState as Tx; - builder.init(&self.builder_config)?; - let is_autocommit_before = self.conn.is_autocommit(); + let state = this.lock().state.clone(); + let mut results = Vec::with_capacity(pgm.steps.len()); + builder.init(&this.lock().builder_config)?; + let mut previous_state = this + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))?; + + let mut has_timeout = false; for step in pgm.steps() { - let res = self.execute_step(step, &results, &mut builder)?; + let mut lock = this.lock(); + + if let Some(slot) = &lock.slot { + if
slot.is_stolen.load(Ordering::Relaxed) || Instant::now() > slot.timeout_at { + // we mark ourselves as stolen to notify any waiting lock thief. + slot.is_stolen.store(true, Ordering::Relaxed); + lock.rollback(); + has_timeout = true; + } + } + + // once there was a timeout, invalidate all the program steps + if has_timeout { + lock.slot = None; + builder.begin_step()?; + builder.step_error(Error::LibSqlTxTimeout)?; + builder.finish_step(0, None)?; + continue; + } + + let res = lock.execute_step(step, &results, &mut builder)?; + + let new_state = lock.conn.transaction_state(Some(DatabaseName::Main))?; + match (previous_state, new_state) { + // lock was upgraded, claim the slot + (Tx::None | Tx::Read, Tx::Write) => { + let slot = Arc::new(TxnSlot { + conn: this.clone(), + timeout_at: Instant::now() + TXN_TIMEOUT, + is_stolen: AtomicBool::new(false), + }); + + lock.slot.replace(slot.clone()); + state.slot.write().replace(slot); + } + // lock was downgraded, notify a waiter + (Tx::Write, Tx::None | Tx::Read) => { + state.slot.write().take(); + lock.slot.take(); + state.notify.notify_one(); + } + // nothing to do + (_, _) => (), + } + + previous_state = new_state; + results.push(res); } - // A transaction is still open, set up a timeout - if is_autocommit_before && !self.conn.is_autocommit() { - self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT) - } + builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?; - builder.finish(*self.current_frame_no_receiver.borrow_and_update())?; + let state = if matches!( + this.lock() + .conn + .transaction_state(Some(DatabaseName::Main))?, + Tx::Read | Tx::Write + ) { + State::Txn + } else { + State::Init + }; - Ok(builder) + Ok((builder, state)) } fn execute_step( @@ -537,7 +666,11 @@ fn check_describe_auth(auth: Authenticated) -> Result<()> { } #[async_trait::async_trait] -impl super::Connection for LibSqlConnection { +impl super::Connection for LibSqlConnection +where + W: WalHook + 'static, + W::Context: 
Send, +{ async fn execute_program( &self, pgm: Program, @@ -546,29 +679,10 @@ impl super::Connection for LibSqlConnection { _replication_index: Option, ) -> Result<(B, State)> { check_program_auth(auth, &pgm)?; - let (resp, receiver) = oneshot::channel(); - let cb = Box::new(move |maybe_conn: Result<&mut Connection>| { - let res = maybe_conn.and_then(|c| { - let b = c.run(pgm, builder)?; - let state = if c.conn.is_autocommit() { - State::Init - } else { - State::Txn - }; - - Ok((b, state)) - }); - - if resp.send(res).is_err() { - anyhow::bail!("connection closed"); - } - - Ok(()) - }); - - let _: Result<_, _> = self.sender.send(cb); - - Ok(receiver.await??) + let conn = self.inner.clone(); + tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder)) + .await + .unwrap() } async fn describe( @@ -578,74 +692,57 @@ impl super::Connection for LibSqlConnection { _replication_index: Option, ) -> Result { check_describe_auth(auth)?; - let (resp, receiver) = oneshot::channel(); - let cb = Box::new(move |maybe_conn: Result<&mut Connection>| { - let res = maybe_conn.and_then(|c| c.describe(&sql)); + let conn = self.inner.clone(); + let res = tokio::task::spawn_blocking(move || conn.lock().describe(&sql)) + .await + .unwrap(); - if resp.send(res).is_err() { - anyhow::bail!("connection closed"); - } - - Ok(()) - }); - - let _: Result<_, _> = self.sender.send(cb); - - Ok(receiver.await?) + Ok(res) } async fn is_autocommit(&self) -> Result { - let (resp, receiver) = oneshot::channel(); - let cb = Box::new(move |maybe_conn: Result<&mut Connection>| { - let res = maybe_conn.map(|c| c.is_autocommit()); - if resp.send(res).is_err() { - anyhow::bail!("connection closed"); - } - Ok(()) - }); - - let _: Result<_, _> = self.sender.send(cb); - receiver.await? 
+ Ok(self.inner.lock().is_autocommit()) } async fn checkpoint(&self) -> Result<()> { - let (resp, receiver) = oneshot::channel(); - let cb = Box::new(move |maybe_conn: Result<&mut Connection>| { - let res = maybe_conn.and_then(|c| c.checkpoint()); - if resp.send(res).is_err() { - anyhow::bail!("connection closed"); - } - Ok(()) - }); - - let _: Result<_, _> = self.sender.send(cb); - receiver.await? + let conn = self.inner.clone(); + tokio::task::spawn_blocking(move || conn.lock().checkpoint()) + .await + .unwrap()?; + Ok(()) } } #[cfg(test)] mod test { use itertools::Itertools; + use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS; + use tempfile::tempdir; + use tokio::task::JoinSet; - use crate::query_result_builder::{test::test_driver, IgnoreResult}; + use crate::query_result_builder::test::{test_driver, TestBuilder}; + use crate::query_result_builder::QueryResultBuilder; + use crate::DEFAULT_AUTO_CHECKPOINT; use super::*; - fn setup_test_conn(ctx: &mut ()) -> Connection { - let mut conn = Connection { - timeout_deadline: None, - conn: sqld_libsql_bindings::Connection::test(ctx), - timed_out: false, + fn setup_test_conn() -> Arc> { + let conn = Connection { + conn: sqld_libsql_bindings::Connection::test(), stats: Arc::new(Stats::default()), config_store: Arc::new(DatabaseConfigStore::new_test()), builder_config: QueryBuilderConfig::default(), current_frame_no_receiver: watch::channel(None).1, + state: Default::default(), + slot: None, }; + let conn = Arc::new(Mutex::new(conn)); + let stmts = std::iter::once("create table test (x)") .chain(std::iter::repeat("insert into test values ('hello world')").take(100)) .collect_vec(); - conn.run(Program::seq(&stmts), IgnoreResult).unwrap(); + Connection::run(conn.clone(), Program::seq(&stmts), TestBuilder::default()).unwrap(); conn } @@ -653,9 +750,167 @@ mod test { #[test] fn test_libsql_conn_builder_driver() { test_driver(1000, |b| { - let ctx = &mut (); - let mut conn = setup_test_conn(ctx); - 
conn.run(Program::seq(&["select * from test"]), b) + let conn = setup_test_conn(); + Connection::run(conn, Program::seq(&["select * from test"]), b).map(|x| x.0) }) } + + #[tokio::test] + async fn txn_timeout_no_stealing() { + let tmp = tempdir().unwrap(); + let make_conn = MakeLibSqlConn::new( + tmp.path().into(), + &TRANSPARENT_METHODS, + || (), + Default::default(), + Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()), + Arc::new([]), + 100000000, + 100000000, + DEFAULT_AUTO_CHECKPOINT, + watch::channel(None).1, + ) + .await + .unwrap(); + + tokio::time::pause(); + let conn = make_conn.make_connection().await.unwrap(); + let (_builder, state) = Connection::run( + conn.inner.clone(), + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + + tokio::time::advance(TXN_TIMEOUT * 2).await; + + let (builder, state) = Connection::run( + conn.inner.clone(), + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Init); + assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); + } + + #[tokio::test] + /// A bunch of txn try to acquire the lock, and never release it. They will try to steal the + /// lock one after the other. 
All txn should eventually acquire the write lock + async fn serialized_txn_timeouts() { + let tmp = tempdir().unwrap(); + let make_conn = MakeLibSqlConn::new( + tmp.path().into(), + &TRANSPARENT_METHODS, + || (), + Default::default(), + Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()), + Arc::new([]), + 100000000, + 100000000, + DEFAULT_AUTO_CHECKPOINT, + watch::channel(None).1, + ) + .await + .unwrap(); + + let mut set = JoinSet::new(); + for _ in 0..10 { + let conn = make_conn.make_connection().await.unwrap(); + set.spawn_blocking(move || { + let (builder, state) = Connection::run( + conn.inner, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + }); + } + + tokio::time::pause(); + + while let Some(ret) = set.join_next().await { + assert!(ret.is_ok()); + // advance time by a bit more than the txn timeout + tokio::time::advance(TXN_TIMEOUT + Duration::from_millis(100)).await; + } + } + + #[tokio::test] + /// verify that releasing a txn before the timeout + async fn release_before_timeout() { + let tmp = tempdir().unwrap(); + let make_conn = MakeLibSqlConn::new( + tmp.path().into(), + &TRANSPARENT_METHODS, + || (), + Default::default(), + Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()), + Arc::new([]), + 100000000, + 100000000, + DEFAULT_AUTO_CHECKPOINT, + watch::channel(None).1, + ) + .await + .unwrap(); + + let conn1 = make_conn.make_connection().await.unwrap(); + tokio::task::spawn_blocking({ + let conn = conn1.inner.clone(); + move || { + let (builder, state) = Connection::run( + conn, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + } + }) + .await + .unwrap(); + + let conn2 = make_conn.make_connection().await.unwrap(); + let handle = tokio::task::spawn_blocking({ + let conn = conn2.inner.clone(); + move || { + let before = 
Instant::now(); + let (builder, state) = Connection::run( + conn, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + before.elapsed() + } + }); + + let wait_time = TXN_TIMEOUT / 10; + tokio::time::sleep(wait_time).await; + + tokio::task::spawn_blocking({ + let conn = conn1.inner.clone(); + move || { + let (builder, state) = + Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) + .unwrap(); + assert_eq!(state, State::Init); + assert!(builder.into_ret()[0].is_ok()); + } + }) + .await + .unwrap(); + + let elapsed = handle.await.unwrap(); + + let epsilon = Duration::from_millis(100); + assert!((wait_time..wait_time + epsilon).contains(&elapsed)); + } } diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs index 24b49aea..7c986928 100644 --- a/sqld/src/connection/write_proxy.rs +++ b/sqld/src/connection/write_proxy.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use parking_lot::Mutex as PMutex; use rusqlite::types::ValueRef; -use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS; +use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; use tokio::sync::{watch, Mutex}; use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; @@ -27,27 +27,24 @@ use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; use super::config::DatabaseConfigStore; -use super::libsql::LibSqlConnection; +use super::libsql::{LibSqlConnection, MakeLibSqlConn}; use super::program::DescribeResult; use super::Connection; use super::{MakeConnection, Program}; -#[derive(Clone)] -pub struct MakeWriteProxyConnection { +pub struct MakeWriteProxyConn { client: ProxyClient, - db_path: PathBuf, - extensions: Arc<[PathBuf]>, stats: Arc, - config_store: Arc, applied_frame_no_receiver: watch::Receiver>, max_response_size: u64, max_total_response_size: u64, namespace: NamespaceName, + make_read_only_conn: 
MakeLibSqlConn, } -impl MakeWriteProxyConnection { +impl MakeWriteProxyConn { #[allow(clippy::too_many_arguments)] - pub fn new( + pub async fn new( db_path: PathBuf, extensions: Arc<[PathBuf]>, channel: Channel, @@ -58,32 +55,41 @@ impl MakeWriteProxyConnection { max_response_size: u64, max_total_response_size: u64, namespace: NamespaceName, - ) -> Self { + ) -> crate::Result { let client = ProxyClient::with_origin(channel, uri); - Self { + let make_read_only_conn = MakeLibSqlConn::new( + db_path.clone(), + &TRANSPARENT_METHODS, + || (), + stats.clone(), + config_store.clone(), + extensions.clone(), + max_response_size, + max_total_response_size, + DEFAULT_AUTO_CHECKPOINT, + applied_frame_no_receiver.clone(), + ) + .await?; + + Ok(Self { client, - db_path, - extensions, stats, - config_store, applied_frame_no_receiver, max_response_size, max_total_response_size, namespace, - } + make_read_only_conn, + }) } } #[async_trait::async_trait] -impl MakeConnection for MakeWriteProxyConnection { +impl MakeConnection for MakeWriteProxyConn { type Connection = WriteProxyConnection; async fn create(&self) -> Result { let db = WriteProxyConnection::new( self.client.clone(), - self.db_path.clone(), - self.extensions.clone(), self.stats.clone(), - self.config_store.clone(), self.applied_frame_no_receiver.clone(), QueryBuilderConfig { max_size: Some(self.max_response_size), @@ -91,6 +97,7 @@ impl MakeConnection for MakeWriteProxyConnection { auto_checkpoint: DEFAULT_AUTO_CHECKPOINT, }, self.namespace.clone(), + self.make_read_only_conn.create().await?, ) .await?; Ok(db) @@ -99,7 +106,7 @@ impl MakeConnection for MakeWriteProxyConnection { pub struct WriteProxyConnection { /// Lazily initialized read connection - read_conn: LibSqlConnection, + read_conn: LibSqlConnection, write_proxy: ProxyClient, state: Mutex, client_id: Uuid, @@ -163,26 +170,12 @@ impl WriteProxyConnection { #[allow(clippy::too_many_arguments)] async fn new( write_proxy: ProxyClient, - db_path: PathBuf, - 
extensions: Arc<[PathBuf]>, stats: Arc, - config_store: Arc, applied_frame_no_receiver: watch::Receiver>, builder_config: QueryBuilderConfig, namespace: NamespaceName, + read_conn: LibSqlConnection, ) -> Result { - let read_conn = LibSqlConnection::new( - db_path, - extensions, - &TRANSPARENT_METHODS, - (), - stats.clone(), - config_store, - builder_config, - applied_frame_no_receiver.clone(), - ) - .await?; - Ok(Self { read_conn, write_proxy, diff --git a/sqld/src/database.rs b/sqld/src/database.rs index 5b65e3c9..60ca8e4b 100644 --- a/sqld/src/database.rs +++ b/sqld/src/database.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::connection::libsql::LibSqlConnection; use crate::connection::write_proxy::WriteProxyConnection; use crate::connection::{Connection, MakeConnection, TrackedConnection}; -use crate::replication::ReplicationLogger; +use crate::replication::{ReplicationLogger, ReplicationLoggerHook}; pub trait Database: Sync + Send + 'static { /// The connection type of the database @@ -28,13 +28,15 @@ impl Database for ReplicaDatabase { fn shutdown(&self) {} } +pub type PrimaryConnection = TrackedConnection>; + pub struct PrimaryDatabase { pub logger: Arc, - pub connection_maker: Arc>>, + pub connection_maker: Arc>, } impl Database for PrimaryDatabase { - type Connection = TrackedConnection; + type Connection = PrimaryConnection; fn connection_maker(&self) -> Arc> { self.connection_maker.clone() diff --git a/sqld/src/namespace/mod.rs b/sqld/src/namespace/mod.rs index 6651fe6d..d54499b5 100644 --- a/sqld/src/namespace/mod.rs +++ b/sqld/src/namespace/mod.rs @@ -24,8 +24,8 @@ use uuid::Uuid; use crate::auth::Authenticated; use crate::connection::config::DatabaseConfigStore; -use crate::connection::libsql::{open_db, LibSqlDbFactory}; -use crate::connection::write_proxy::MakeWriteProxyConnection; +use crate::connection::libsql::{open_conn, MakeLibSqlConn}; +use crate::connection::write_proxy::MakeWriteProxyConn; use crate::connection::MakeConnection; use 
crate::database::{Database, PrimaryDatabase, ReplicaDatabase}; use crate::error::{Error, LoadDumpError}; @@ -577,7 +577,7 @@ impl Namespace { join_set.spawn(replicator.run()); - let connection_maker = MakeWriteProxyConnection::new( + let connection_maker = MakeWriteProxyConn::new( db_path.clone(), config.extensions.clone(), config.channel.clone(), @@ -589,6 +589,7 @@ impl Namespace { config.max_total_response_size, name.clone(), ) + .await? .throttled( MAX_CONCURRENT_DBS, Some(DB_CREATE_TIMEOUT), @@ -718,7 +719,7 @@ impl Namespace { DatabaseConfigStore::load(&db_path).context("Could not load database config")?, ); - let connection_maker: Arc<_> = LibSqlDbFactory::new( + let connection_maker: Arc<_> = MakeLibSqlConn::new( db_path.clone(), &REPLICATION_METHODS, ctx_builder.clone(), @@ -738,13 +739,12 @@ impl Namespace { ) .into(); - let mut ctx = ctx_builder(); match restore_option { RestoreOption::Dump(_) if !is_fresh_db => { Err(LoadDumpError::LoadDumpExistingDb)?; } RestoreOption::Dump(dump) => { - load_dump(&db_path, dump, &mut ctx).await?; + load_dump(&db_path, dump, ctx_builder, logger.auto_checkpoint).await?; } _ => { /* other cases were already handled when creating bottomless */ } } @@ -828,17 +828,24 @@ const WASM_TABLE_CREATE: &str = async fn load_dump( db_path: &Path, dump: S, - ctx: &mut ReplicationLoggerHookCtx, + mk_ctx: impl Fn() -> ReplicationLoggerHookCtx, + auto_checkpoint: u32, ) -> anyhow::Result<()> where S: Stream> + Unpin, { let mut retries = 0; - let auto_checkpoint = ctx.logger().auto_checkpoint; // there is a small chance we fail to acquire the lock right away, so we perform a few retries let conn = loop { - match block_in_place(|| open_db(db_path, &REPLICATION_METHODS, ctx, None, auto_checkpoint)) - { + match block_in_place(|| { + open_conn( + db_path, + &REPLICATION_METHODS, + mk_ctx(), + None, + auto_checkpoint, + ) + }) { Ok(conn) => { break conn; } @@ -977,10 +984,9 @@ async fn run_storage_monitor(db_path: PathBuf, stats: Weak) -> 
anyhow::Re // because closing the last connection interferes with opening a new one, we lazily // initialize a connection here, and keep it alive for the entirety of the program. If we // fail to open it, we wait for `duration` and try again later. - let ctx = &mut (); // We can safely open db with DEFAULT_AUTO_CHECKPOINT, since monitor is read-only: it // won't produce new updates, frames or generate checkpoints. - match open_db(&db_path, &TRANSPARENT_METHODS, ctx, Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), DEFAULT_AUTO_CHECKPOINT) { + match open_conn(&db_path, &TRANSPARENT_METHODS, (), Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), DEFAULT_AUTO_CHECKPOINT) { Ok(conn) => { if let Ok(storage_bytes_used) = conn.query_row("select sum(pgsize) from dbstat;", [], |row| { diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index 36820339..914037ee 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -503,8 +503,111 @@ pub mod test { }; use FsmState::*; + use crate::query::Value; + use super::*; + #[derive(Default)] + pub struct TestBuilder { + steps: Vec, + current_step: StepResultBuilder, + } + + pub type Row = Vec; + pub type StepResult = crate::Result>; + + #[derive(Default)] + pub struct StepResultBuilder { + rows: Vec, + current_row: Row, + err: Option, + } + + impl QueryResultBuilder for TestBuilder { + type Ret = Vec; + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.steps.clear(); + self.current_step = Default::default(); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + let current = std::mem::take(&mut self.current_step); + if let Some(err) = current.err { + self.steps.push(Err(err)); + } else { + self.steps.push(Ok(current.rows)); + } + + Ok(()) + } + + fn 
step_error( + &mut self, + error: crate::error::Error, + ) -> Result<(), QueryResultBuilderError> { + self.current_step.err = Some(error); + Ok(()) + } + + fn cols_description<'a>( + &mut self, + _cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + let v = match v { + ValueRef::Null => Value::Null, + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(x) => Value::Real(x), + ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), + ValueRef::Blob(x) => Value::Blob(x.to_vec()), + }; + self.current_step.current_row.push(v); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + let row = std::mem::take(&mut self.current_step.current_row); + self.current_step.rows.push(row); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn finish( + &mut self, + _last_frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn into_ret(self) -> Self::Ret { + self.steps + } + } + /// a dummy QueryResultBuilder that encodes the QueryResultBuilder FSM. 
It can be passed to a /// driver to ensure that it is not mis-used diff --git a/sqld/src/replication/replica/injector.rs b/sqld/src/replication/replica/injector.rs index 9d32223d..28bdd333 100644 --- a/sqld/src/replication/replica/injector.rs +++ b/sqld/src/replication/replica/injector.rs @@ -5,14 +5,14 @@ use rusqlite::OpenFlags; use crate::replication::replica::hook::{SQLITE_CONTINUE_REPLICATION, SQLITE_EXIT_REPLICATION}; -use super::hook::{InjectorHookCtx, INJECTOR_METHODS}; +use super::hook::{InjectorHook, InjectorHookCtx, INJECTOR_METHODS}; -pub struct FrameInjector<'a> { - conn: sqld_libsql_bindings::Connection<'a>, +pub struct FrameInjector { + conn: sqld_libsql_bindings::Connection, } -impl<'a> FrameInjector<'a> { - pub fn new(db_path: &Path, hook_ctx: &'a mut InjectorHookCtx) -> anyhow::Result { +impl FrameInjector { + pub fn new(db_path: &Path, hook_ctx: InjectorHookCtx) -> anyhow::Result { let conn = sqld_libsql_bindings::Connection::open( db_path, OpenFlags::SQLITE_OPEN_READ_WRITE diff --git a/sqld/src/replication/replica/replicator.rs b/sqld/src/replication/replica/replicator.rs index e4c8bc35..230b6ee1 100644 --- a/sqld/src/replication/replica/replicator.rs +++ b/sqld/src/replication/replica/replicator.rs @@ -109,8 +109,8 @@ impl Replicator { let handle = BLOCKING_RT.spawn_blocking({ let db_path = db_path; move || -> anyhow::Result<()> { - let mut ctx = InjectorHookCtx::new(receiver, pre_commit, post_commit); - let mut injector = FrameInjector::new(&db_path, &mut ctx)?; + let ctx = InjectorHookCtx::new(receiver, pre_commit, post_commit); + let mut injector = FrameInjector::new(&db_path, ctx)?; let _ = snd.send(()); while injector.step()? 
{} diff --git a/sqld/src/rpc/proxy.rs b/sqld/src/rpc/proxy.rs index 3ec7392b..5ba004a7 100644 --- a/sqld/src/rpc/proxy.rs +++ b/sqld/src/rpc/proxy.rs @@ -6,9 +6,8 @@ use async_lock::{RwLock, RwLockUpgradableReadGuard}; use uuid::Uuid; use crate::auth::{Auth, Authenticated}; -use crate::connection::libsql::LibSqlConnection; -use crate::connection::{Connection, TrackedConnection}; -use crate::database::Database; +use crate::connection::Connection; +use crate::database::{Database, PrimaryConnection}; use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, @@ -265,7 +264,7 @@ pub mod rpc { } pub struct ProxyService { - clients: Arc>>>>, + clients: Arc>>>, namespaces: NamespaceStore, auth: Option>, disable_namespaces: bool, @@ -277,17 +276,15 @@ impl ProxyService { auth: Option>, disable_namespaces: bool, ) -> Self { - let clients: Arc>>>> = - Default::default(); Self { - clients, + clients: Default::default(), namespaces, auth, disable_namespaces, } } - pub fn clients(&self) -> Arc>>>> { + pub fn clients(&self) -> Arc>>> { self.clients.clone() } } @@ -451,13 +448,11 @@ impl QueryResultBuilder for ExecuteResultBuilder { // FIXME: we should also keep a list of recently disconnected clients, // and if one should arrive with a late message, it should be rejected // with an error. A similar mechanism is already implemented in hrana-over-http. -pub async fn garbage_collect( - clients: &mut HashMap>>, -) { +pub async fn garbage_collect(clients: &mut HashMap>) { let limit = std::time::Duration::from_secs(30); clients.retain(|_, db| db.idle_time() < limit); - tracing::trace!("gc: remaining client handles: {:?}", clients); + tracing::trace!("gc: remaining client handles count: {}", clients.len()); } #[tonic::async_trait]