diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs
index e8af119a..dcb8a04c 100644
--- a/sqld/src/connection/libsql.rs
+++ b/sqld/src/connection/libsql.rs
@@ -28,7 +28,7 @@ pub struct LibSqlDbFactory<W: WalHook + 'static> {
     db_path: PathBuf,
     hook: &'static WalMethodsHook<W>,
     ctx_builder: Box<dyn Fn() -> W::Context + Sync + Send + 'static>,
-    stats: Stats,
+    stats: Arc<Stats>,
     config_store: Arc<DatabaseConfigStore>,
     extensions: Arc<[PathBuf]>,
     max_response_size: u64,
@@ -49,7 +49,7 @@ where
         db_path: PathBuf,
         hook: &'static WalMethodsHook<W>,
         ctx_builder: F,
-        stats: Stats,
+        stats: Arc<Stats>,
         config_store: Arc<DatabaseConfigStore>,
         extensions: Arc<[PathBuf]>,
         max_response_size: u64,
@@ -168,7 +168,7 @@ impl<W: WalHook> LibSqlConnection<W> {
         extensions: Arc<[PathBuf]>,
         wal_hook: &'static WalMethodsHook<W>,
         hook_ctx: W::Context,
-        stats: Stats,
+        stats: Arc<Stats>,
         config_store: Arc<DatabaseConfigStore>,
         builder_config: QueryBuilderConfig,
     ) -> crate::Result<Self>
@@ -242,7 +242,7 @@ struct Connection<'a> {
     timeout_deadline: Option<Instant>,
     conn: sqld_libsql_bindings::Connection<'a>,
     timed_out: bool,
-    stats: Stats,
+    stats: Arc<Stats>,
     config_store: Arc<DatabaseConfigStore>,
     builder_config: QueryBuilderConfig,
 }
@@ -253,7 +253,7 @@ impl<'a> Connection<'a> {
         extensions: Arc<[PathBuf]>,
         wal_methods: &'static WalMethodsHook<W>,
         hook_ctx: &'a mut W::Context,
-        stats: Stats,
+        stats: Arc<Stats>,
         config_store: Arc<DatabaseConfigStore>,
         builder_config: QueryBuilderConfig,
     ) -> Result<Self> {
@@ -612,7 +612,7 @@ mod test {
             timeout_deadline: None,
             conn: sqld_libsql_bindings::Connection::test(ctx),
             timed_out: false,
-            stats: Stats::default(),
+            stats: Arc::new(Stats::default()),
             config_store: Arc::new(DatabaseConfigStore::new_test()),
             builder_config: QueryBuilderConfig::default(),
         };
diff --git a/sqld/src/connection/write_proxy.rs b/sqld/src/connection/write_proxy.rs
index bd9fa85b..685da0d0 100644
--- a/sqld/src/connection/write_proxy.rs
+++ b/sqld/src/connection/write_proxy.rs
@@ -37,7 +37,7 @@ pub struct MakeWriteProxyConnection {
     client: ProxyClient<Channel>,
     db_path: PathBuf,
     extensions: Arc<[PathBuf]>,
-    stats: Stats,
+    stats: Arc<Stats>,
     config_store: Arc<DatabaseConfigStore>,
     applied_frame_no_receiver: watch::Receiver<FrameNo>,
     max_response_size: u64,
@@ -52,7 +52,7 @@ impl MakeWriteProxyConnection {
         extensions: Arc<[PathBuf]>,
         channel: Channel,
         uri: tonic::transport::Uri,
-        stats: Stats,
+        stats: Arc<Stats>,
         config_store: Arc<DatabaseConfigStore>,
         applied_frame_no_receiver: watch::Receiver<FrameNo>,
         max_response_size: u64,
@@ -110,7 +110,7 @@ pub struct WriteProxyConnection {
     /// Notifier from the repliator of the currently applied frameno
     applied_frame_no_receiver: watch::Receiver<FrameNo>,
     builder_config: QueryBuilderConfig,
-    stats: Stats,
+    stats: Arc<Stats>,
     /// bytes representing the namespace name
     namespace: Bytes,
 }
@@ -166,7 +166,7 @@ impl WriteProxyConnection {
         write_proxy: ProxyClient<Channel>,
         db_path: PathBuf,
         extensions: Arc<[PathBuf]>,
-        stats: Stats,
+        stats: Arc<Stats>,
         config_store: Arc<DatabaseConfigStore>,
         applied_frame_no_receiver: watch::Receiver<FrameNo>,
         builder_config: QueryBuilderConfig,
diff --git a/sqld/src/heartbeat.rs b/sqld/src/heartbeat.rs
index c15a8d34..e72621a6 100644
--- a/sqld/src/heartbeat.rs
+++ b/sqld/src/heartbeat.rs
@@ -1,28 +1,62 @@
+use bytes::Bytes;
+use std::collections::HashMap;
+use std::sync::Weak;
 use std::time::Duration;
-use tokio::time::sleep;
+use tokio::sync::mpsc;
+use url::Url;
 
-use crate::http::stats::StatsResponse;
+use crate::http::admin::stats::StatsResponse;
 use crate::stats::Stats;
 
 pub async fn server_heartbeat(
-    url: String,
+    url: Url,
     auth: Option<String>,
     update_period: Duration,
-    stats: Stats,
+    mut stats_subs: mpsc::Receiver<(Bytes, Weak<Stats>)>,
 ) {
+    let mut watched = HashMap::new();
     let client = reqwest::Client::new();
+    let mut interval = tokio::time::interval(update_period);
+    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
     loop {
-        sleep(update_period).await;
-        let body = StatsResponse::from(&stats);
-        let request = client.post(&url);
-        let request = if let Some(ref auth) = auth {
-            request.header("Authorization", auth.clone())
-        } else {
-            request
+        tokio::select! {
+            Some((ns, stats)) = stats_subs.recv() => {
+                watched.insert(ns, stats);
+            }
+            _ = interval.tick() => {
+                send_stats(&mut watched, &client, &url, auth.as_deref()).await;
+            }
         };
-        let request = request.json(&body);
-        if let Err(err) = request.send().await {
-            tracing::warn!("Error sending heartbeat: {}", err);
+    }
+}
+
+async fn send_stats(
+    watched: &mut HashMap<Bytes, Weak<Stats>>,
+    client: &reqwest::Client,
+    url: &Url,
+    auth: Option<&str>,
+) {
+    // first send all the stats...
+    for (ns, stats) in watched.iter() {
+        if let Some(stats) = stats.upgrade() {
+            let body = StatsResponse::from(stats.as_ref());
+            let mut url = url.clone();
+            url.path_segments_mut()
+                .unwrap()
+                .push(std::str::from_utf8(ns).unwrap());
+            let request = client.post(url);
+            let request = if let Some(ref auth) = auth {
+                request.header("Authorization", auth.to_string())
+            } else {
+                request
+            };
+            let request = request.json(&body);
+            if let Err(err) = request.send().await {
+                tracing::warn!("Error sending heartbeat: {}", err);
+            }
         }
     }
+
+    // ...and then remove all expired subscriptions
+    watched.retain(|_, s| s.upgrade().is_some());
 }
diff --git a/sqld/src/hrana/ws/handshake.rs b/sqld/src/hrana/ws/handshake.rs
index 46580f88..5df73a63 100644
--- a/sqld/src/hrana/ws/handshake.rs
+++ b/sqld/src/hrana/ws/handshake.rs
@@ -4,7 +4,7 @@ use futures::{SinkExt as _, StreamExt as _};
 use tokio_tungstenite::tungstenite;
 use tungstenite::http;
 
-use crate::http::db_factory::namespace_from_headers;
+use crate::http::user::db_factory::namespace_from_headers;
 use crate::net::Conn;
 
 use super::super::{Encoding, Version};
diff --git a/sqld/src/admin_api.rs b/sqld/src/http/admin/mod.rs
similarity index 97%
rename from sqld/src/admin_api.rs
rename to sqld/src/http/admin/mod.rs
index 28b0b4f3..9476e5ba 100644
--- a/sqld/src/admin_api.rs
+++ b/sqld/src/http/admin/mod.rs
@@ -15,12 +15,14 @@ use crate::connection::config::{DatabaseConfig, DatabaseConfigStore};
 use crate::error::LoadDumpError;
 use crate::namespace::{DumpStream, MakeNamespace, NamespaceStore, RestoreOption};
 
+pub mod stats;
+
 struct AppState<M: MakeNamespace> {
     db_config_store: Arc<DatabaseConfigStore>,
     namespaces: NamespaceStore<M>,
 }
 
-pub async fn run_admin_api<M, A>(
+pub async fn run<M, A>(
     acceptor: A,
     db_config_store: Arc<DatabaseConfigStore>,
     namespaces: NamespaceStore<M>,
@@ -47,6 +49,7 @@ where
             post(handle_restore_namespace),
         )
         .route("/v1/namespaces/:namespace", delete(handle_delete_namespace))
+        .route("/v1/namespaces/:namespace/stats", get(stats::handle_stats))
         .with_state(Arc::new(AppState {
             db_config_store,
             namespaces,
diff --git a/sqld/src/http/stats.rs b/sqld/src/http/admin/stats.rs
similarity index 53%
rename from sqld/src/http/stats.rs
rename to sqld/src/http/admin/stats.rs
index 245aa2a2..9bfaec11 100644
--- a/sqld/src/http/stats.rs
+++ b/sqld/src/http/admin/stats.rs
@@ -1,9 +1,12 @@
-use hyper::{Body, Response};
+use std::sync::Arc;
+
 use serde::Serialize;
 
-use axum::extract::{FromRef, State as AxumState};
+use axum::extract::{Path, State};
+use axum::Json;
 
-use crate::{namespace::MakeNamespace, stats::Stats};
+use crate::namespace::MakeNamespace;
+use crate::stats::Stats;
 
 use super::AppState;
 
@@ -32,18 +35,12 @@ impl From<&Stats> for StatsResponse {
     }
 }
 
-impl<M: MakeNamespace> FromRef<AppState<M>> for Stats {
-    fn from_ref(input: &AppState<M>) -> Self {
-        input.stats.clone()
-    }
-}
-
-pub(crate) async fn handle_stats(AxumState(stats): AxumState<Stats>) -> Response<Body> {
-    let resp: StatsResponse = stats.into();
+pub(super) async fn handle_stats<M: MakeNamespace>(
+    State(app_state): State<Arc<AppState<M>>>,
+    Path(namespace): Path<String>,
+) -> crate::Result<Json<StatsResponse>> {
+    let stats = app_state.namespaces.stats(namespace.into()).await?;
+    let resp: StatsResponse = stats.as_ref().into();
 
-    let payload = serde_json::to_vec(&resp).unwrap();
-    Response::builder()
-        .header("Content-Type", "application/json")
-        .body(Body::from(payload))
-        .unwrap()
+    Ok(Json(resp))
 }
diff --git a/sqld/src/http/mod.rs b/sqld/src/http/mod.rs
index 64533bce..1e6bf65b 100644
--- a/sqld/src/http/mod.rs
+++ b/sqld/src/http/mod.rs
@@ -1,481 +1,2 @@
-pub mod db_factory;
-mod dump;
-mod hrana_over_http_1;
-mod result_builder;
-pub mod stats;
-mod types;
-
-use std::path::Path;
-use std::sync::Arc;
-
-use anyhow::Context;
-use axum::extract::{FromRef, FromRequest, FromRequestParts, State as AxumState};
-use axum::http::request::Parts;
-use axum::http::HeaderValue;
-use axum::response::{Html, IntoResponse};
-use axum::routing::{get, post};
-use axum::Router;
-use axum_extra::middleware::option_layer;
-use base64::prelude::BASE64_STANDARD_NO_PAD;
-use base64::Engine;
-use hyper::{header, Body, Request, Response, StatusCode};
-use serde::de::DeserializeOwned;
-use serde::Serialize;
-use serde_json::Number;
-use tokio::sync::{mpsc, oneshot};
-use tokio::task::JoinSet;
-use tonic::transport::Server;
-use tower_http::trace::DefaultOnResponse;
-use tower_http::{compression::CompressionLayer, cors};
-use tracing::{Level, Span};
-
-use crate::auth::{Auth, Authenticated};
-use crate::connection::Connection;
-use crate::database::Database;
-use crate::error::Error;
-use crate::hrana;
-use crate::http::types::HttpQuery;
-use crate::namespace::{MakeNamespace, NamespaceStore};
-use crate::net::Accept;
-use crate::query::{self, Query};
-use crate::query_analysis::{predict_final_state, State, Statement};
-use crate::query_result_builder::QueryResultBuilder;
-use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer};
-use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog;
-use crate::rpc::ReplicationLogServer;
-use crate::stats::Stats;
-use crate::utils::services::idle_shutdown::IdleShutdownKicker;
-use crate::version;
-
-use self::db_factory::MakeConnectionExtractor;
-use self::result_builder::JsonHttpPayloadBuilder;
-use self::types::QueryObject;
-
-impl TryFrom<query::Value> for serde_json::Value {
-    type Error = Error;
-
-    fn try_from(value: query::Value) -> Result<Self, Self::Error> {
-        let value = match value {
-            query::Value::Null => serde_json::Value::Null,
-            query::Value::Integer(i) => serde_json::Value::Number(Number::from(i)),
-            query::Value::Real(x) => {
-                serde_json::Value::Number(Number::from_f64(x).ok_or_else(|| {
-                    Error::DbValueError(format!(
-                        "Cannot to convert database value `{x}` to a JSON number"
-                    ))
-                })?)
-            }
-            query::Value::Text(s) => serde_json::Value::String(s),
-            query::Value::Blob(v) => serde_json::json!({
-                "base64": BASE64_STANDARD_NO_PAD.encode(v),
-            }),
-        };
-
-        Ok(value)
-    }
-}
-
-/// Encodes a query response rows into json
-#[derive(Debug, Serialize)]
-struct RowsResponse {
-    columns: Vec<String>,
-    rows: Vec<Vec<serde_json::Value>>,
-}
-
-fn parse_queries(queries: Vec<QueryObject>) -> crate::Result<Vec<Query>> {
-    let mut out = Vec::with_capacity(queries.len());
-    for query in queries {
-        let mut iter = Statement::parse(&query.q);
-        let stmt = iter.next().transpose()?.unwrap_or_default();
-        if iter.next().is_some() {
-            return Err(Error::FailedToParse("found more than one command in a single statement string. It is allowed to issue only one command per string.".to_string()));
-        }
-        let query = Query {
-            stmt,
-            params: query.params.0,
-            want_rows: true,
-        };
-
-        out.push(query);
-    }
-
-    match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) {
-        State::Txn => {
-            return Err(Error::QueryError(
-                "interactive transaction not allowed in HTTP queries".to_string(),
-            ))
-        }
-        State::Init => (),
-        // maybe we should err here, but let's sqlite deal with that.
-        State::Invalid => (),
-    }
-
-    Ok(out)
-}
-
-async fn handle_query<C: Connection>(
-    auth: Authenticated,
-    MakeConnectionExtractor(connection_maker): MakeConnectionExtractor<C>,
-    Json(query): Json<HttpQuery>,
-) -> Result<axum::response::Response, Error> {
-    let batch = parse_queries(query.statements)?;
-
-    let db = connection_maker.create().await?;
-
-    let builder = JsonHttpPayloadBuilder::new();
-    let (builder, _) = db.execute_batch_or_rollback(batch, auth, builder).await?;
-
-    let res = (
-        [(header::CONTENT_TYPE, "application/json")],
-        builder.into_ret(),
-    );
-    Ok(res.into_response())
-}
-
-async fn show_console<M: MakeNamespace>(
-    AxumState(AppState { enable_console, .. }): AxumState<AppState<M>>,
-) -> impl IntoResponse {
-    if enable_console {
-        Html(std::include_str!("console.html")).into_response()
-    } else {
-        StatusCode::NOT_FOUND.into_response()
-    }
-}
-
-async fn handle_health() -> Response<Body> {
-    // return empty OK
-    Response::new(Body::empty())
-}
-
-async fn handle_upgrade<M: MakeNamespace>(
-    AxumState(AppState { upgrade_tx, .. }): AxumState<AppState<M>>,
-    req: Request<Body>,
-) -> impl IntoResponse {
-    if !hyper_tungstenite::is_upgrade_request(&req) {
-        return StatusCode::NOT_FOUND.into_response();
-    }
-
-    let (response_tx, response_rx) = oneshot::channel();
-    let _: Result<_, _> = upgrade_tx
-        .send(hrana::ws::Upgrade {
-            request: req,
-            response_tx,
-        })
-        .await;
-
-    match response_rx.await {
-        Ok(response) => response.into_response(),
-        Err(_) => (
-            StatusCode::SERVICE_UNAVAILABLE,
-            "sqld was not able to process the HTTP upgrade",
-        )
-            .into_response(),
-    }
-}
-
-async fn handle_version() -> Response<Body> {
-    let version = version::version();
-    Response::new(Body::from(version))
-}
-
-async fn handle_fallback() -> impl IntoResponse {
-    (StatusCode::NOT_FOUND).into_response()
-}
-
-/// Router wide state that each request has access too via
-/// axum's `State` extractor.
-pub(crate) struct AppState<M: MakeNamespace> {
-    auth: Arc<Auth>,
-    namespaces: NamespaceStore<M>,
-    upgrade_tx: mpsc::Sender<hrana::ws::Upgrade>,
-    hrana_http_srv: Arc<hrana::http::Server<<M::Database as Database>::Connection>>,
-    enable_console: bool,
-    stats: Stats,
-    disable_default_namespace: bool,
-    disable_namespaces: bool,
-    path: Arc<Path>,
-}
-
-impl<M: MakeNamespace> Clone for AppState<M> {
-    fn clone(&self) -> Self {
-        Self {
-            auth: self.auth.clone(),
-            namespaces: self.namespaces.clone(),
-            upgrade_tx: self.upgrade_tx.clone(),
-            hrana_http_srv: self.hrana_http_srv.clone(),
-            enable_console: self.enable_console,
-            stats: self.stats.clone(),
-            disable_default_namespace: self.disable_default_namespace,
-            disable_namespaces: self.disable_namespaces,
-            path: self.path.clone(),
-        }
-    }
-}
-
-pub struct UserApi<M: MakeNamespace, A, P, S> {
-    pub auth: Arc<Auth>,
-    pub http_acceptor: Option<A>,
-    pub hrana_ws_acceptor: Option<A>,
-    pub namespaces: NamespaceStore<M>,
-    pub idle_shutdown_kicker: Option<IdleShutdownKicker>,
-    pub stats: Stats,
-    pub proxy_service: P,
-    pub replication_service: S,
-    pub disable_default_namespace: bool,
-    pub disable_namespaces: bool,
-    pub max_response_size: u64,
-    pub enable_console: bool,
-    pub self_url: Option<String>,
-    pub path: Arc<Path>,
-}
-
-impl<M, A, P, S> UserApi<M, A, P, S>
-where
-    M: MakeNamespace,
-    A: Accept,
-    P: Proxy,
-    S: ReplicationLog,
-{
-    pub fn configure(self, join_set: &mut JoinSet<anyhow::Result<()>>) {
-        let (hrana_accept_tx, hrana_accept_rx) = mpsc::channel(8);
-        let (hrana_upgrade_tx, hrana_upgrade_rx) = mpsc::channel(8);
-        let hrana_http_srv = Arc::new(hrana::http::Server::new(self.self_url.clone()));
-
-        join_set.spawn({
-            let namespaces = self.namespaces.clone();
-            let auth = self.auth.clone();
-            let idle_kicker = self
-                .idle_shutdown_kicker
-                .clone()
-                .map(|isl| isl.into_kicker());
-            let disable_default_namespace = self.disable_default_namespace;
-            let disable_namespaces = self.disable_namespaces;
-            let max_response_size = self.max_response_size;
-            async move {
-                hrana::ws::serve(
-                    auth,
-                    idle_kicker,
-                    max_response_size,
-                    hrana_accept_rx,
-                    hrana_upgrade_rx,
-                    namespaces,
-                    disable_default_namespace,
-                    disable_namespaces,
-                )
-                .await
-                .context("Hrana server failed")
-            }
-        });
-
-        join_set.spawn({
-            let server = hrana_http_srv.clone();
-            async move {
-                server.run_expire().await;
-                Ok(())
-            }
-        });
-
-        if let Some(acceptor) = self.hrana_ws_acceptor {
-            join_set.spawn(async move {
-                hrana::ws::listen(acceptor, hrana_accept_tx).await;
-                Ok(())
-            });
-        }
-
-        if let Some(acceptor) = self.http_acceptor {
-            let state = AppState {
-                auth: self.auth,
-                upgrade_tx: hrana_upgrade_tx,
-                hrana_http_srv,
-                enable_console: self.enable_console,
-                stats: self.stats.clone(),
-                namespaces: self.namespaces,
-                disable_default_namespace: self.disable_default_namespace,
-                disable_namespaces: self.disable_namespaces,
-                path: self.path,
-            };
-
-            fn trace_request<B>(req: &Request<B>, _span: &Span) {
-                tracing::debug!("got request: {} {}", req.method(), req.uri());
-            }
-
-            macro_rules! handle_hrana {
-                ($endpoint:expr, $version:expr, $encoding:expr,) => {{
-                    async fn handle_hrana<M: MakeNamespace>(
-                        AxumState(state): AxumState<AppState<M>>,
-                        MakeConnectionExtractor(connection_maker): MakeConnectionExtractor<
-                            <M::Database as Database>::Connection,
-                        >,
-                        auth: Authenticated,
-                        req: Request<Body>,
-                    ) -> Result<Response<Body>, Error> {
-                        Ok(state
-                            .hrana_http_srv
-                            .handle_request(
-                                connection_maker,
-                                auth,
-                                req,
-                                $endpoint,
-                                $version,
-                                $encoding,
-                            )
-                            .await?)
-                    }
-                    handle_hrana
-                }};
-            }
-
-            let app = Router::new()
-                .route("/", post(handle_query))
-                .route("/", get(handle_upgrade))
-                .route("/version", get(handle_version))
-                .route("/console", get(show_console))
-                .route("/health", get(handle_health))
-                .route("/dump", get(dump::handle_dump))
-                .route("/v1/stats", get(stats::handle_stats))
-                .route("/v1", get(hrana_over_http_1::handle_index))
-                .route("/v1/execute", post(hrana_over_http_1::handle_execute))
-                .route("/v1/batch", post(hrana_over_http_1::handle_batch))
-                .route("/v2", get(crate::hrana::http::handle_index))
-                .route(
-                    "/v2/pipeline",
-                    post(handle_hrana!(
-                        hrana::http::Endpoint::Pipeline,
-                        hrana::Version::Hrana2,
-                        hrana::Encoding::Json,
-                    )),
-                )
-                .route("/v3", get(crate::hrana::http::handle_index))
-                .route(
-                    "/v3/pipeline",
-                    post(handle_hrana!(
-                        hrana::http::Endpoint::Pipeline,
-                        hrana::Version::Hrana3,
-                        hrana::Encoding::Json,
-                    )),
-                )
-                .route(
-                    "/v3/cursor",
-                    post(handle_hrana!(
-                        hrana::http::Endpoint::Cursor,
-                        hrana::Version::Hrana3,
-                        hrana::Encoding::Json,
-                    )),
-                )
-                .route("/v3-protobuf", get(crate::hrana::http::handle_index))
-                .route(
-                    "/v3-protobuf/pipeline",
-                    post(handle_hrana!(
-                        hrana::http::Endpoint::Pipeline,
-                        hrana::Version::Hrana3,
-                        hrana::Encoding::Protobuf,
-                    )),
-                )
-                .route(
-                    "/v3-protobuf/cursor",
-                    post(handle_hrana!(
-                        hrana::http::Endpoint::Cursor,
-                        hrana::Version::Hrana3,
-                        hrana::Encoding::Protobuf,
-                    )),
-                )
-                .with_state(state);
-
-            let layered_app = app
-                .layer(option_layer(self.idle_shutdown_kicker.clone()))
-                .layer(
-                    tower_http::trace::TraceLayer::new_for_http()
-                        .on_request(trace_request)
-                        .on_response(
-                            DefaultOnResponse::new()
-                                .level(Level::DEBUG)
-                                .latency_unit(tower_http::LatencyUnit::Micros),
-                        ),
-                )
-                .layer(CompressionLayer::new())
-                .layer(
-                    cors::CorsLayer::new()
-                        .allow_methods(cors::AllowMethods::any())
-                        .allow_headers(cors::Any)
-                        .allow_origin(cors::Any),
-                );
-
-            // Merge the grpc based axum router into our regular http router
-            let replication = ReplicationLogServer::new(self.replication_service);
-            let write_proxy = ProxyServer::new(self.proxy_service);
-
-            let grpc_router = Server::builder()
-                .accept_http1(true)
-                .add_service(tonic_web::enable(replication))
-                .add_service(tonic_web::enable(write_proxy))
-                .into_router();
-
-            let router = layered_app.merge(grpc_router);
-
-            let router = router.fallback(handle_fallback);
-            let h2c = crate::h2c::H2cMaker::new(router);
-
-            join_set.spawn(async move {
-                hyper::server::Server::builder(acceptor)
-                    .serve(h2c)
-                    .await
-                    .context("http server")?;
-                Ok(())
-            });
-        }
-    }
-}
-
-/// Axum authenticated extractor
-#[tonic::async_trait]
-impl<S> FromRequestParts<S> for Authenticated
-where
-    Arc<Auth>: FromRef<S>,
-    S: Send + Sync,
-{
-    type Rejection = Error;
-
-    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
-        let auth = <Arc<Auth> as FromRef<S>>::from_ref(state);
-
-        let auth_header = parts.headers.get(hyper::header::AUTHORIZATION);
-        let auth = auth.authenticate_http(auth_header)?;
-
-        Ok(auth)
-    }
-}
-
-impl<M: MakeNamespace> FromRef<AppState<M>> for Arc<Auth> {
-    fn from_ref(input: &AppState<M>) -> Self {
-        input.auth.clone()
-    }
-}
-
-#[derive(Debug, Clone, Copy, Default)]
-#[must_use]
-pub struct Json<T>(pub T);
-
-#[tonic::async_trait]
-impl<T, S, B> FromRequest<S, B> for Json<T>
-where
-    T: DeserializeOwned,
-    B: hyper::body::HttpBody + Send + 'static,
-    B::Data: Send,
-    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
-    S: Send + Sync,
-{
-    type Rejection = axum::extract::rejection::JsonRejection;
-
-    async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
-        let headers = req.headers_mut();
-
-        headers.insert(
-            header::CONTENT_TYPE,
-            HeaderValue::from_static("application/json"),
-        );
-
-        axum::Json::from_request(req, state)
-            .await
-            .map(|t| Json(t.0))
-    }
-}
+pub mod admin;
+pub mod user;
diff --git a/sqld/src/http/console.html b/sqld/src/http/user/console.html
similarity index 100%
rename from sqld/src/http/console.html
rename to sqld/src/http/user/console.html
diff --git a/sqld/src/http/db_factory.rs b/sqld/src/http/user/db_factory.rs
similarity index 100%
rename from sqld/src/http/db_factory.rs
rename to sqld/src/http/user/db_factory.rs
diff --git a/sqld/src/http/dump.rs b/sqld/src/http/user/dump.rs
similarity index 100%
rename from sqld/src/http/dump.rs
rename to sqld/src/http/user/dump.rs
diff --git a/sqld/src/http/hrana_over_http_1.rs b/sqld/src/http/user/hrana_over_http_1.rs
similarity index 100%
rename from sqld/src/http/hrana_over_http_1.rs
rename to sqld/src/http/user/hrana_over_http_1.rs
diff --git a/sqld/src/http/user/mod.rs b/sqld/src/http/user/mod.rs
new file mode 100644
index 00000000..d174889a
--- /dev/null
+++ b/sqld/src/http/user/mod.rs
@@ -0,0 +1,474 @@
+pub mod db_factory;
+mod dump;
+mod hrana_over_http_1;
+mod result_builder;
+mod types;
+
+use std::path::Path;
+use std::sync::Arc;
+
+use anyhow::Context;
+use axum::extract::{FromRef, FromRequest, FromRequestParts, State as AxumState};
+use axum::http::request::Parts;
+use axum::http::HeaderValue;
+use axum::response::{Html, IntoResponse};
+use axum::routing::{get, post};
+use axum::Router;
+use axum_extra::middleware::option_layer;
+use base64::prelude::BASE64_STANDARD_NO_PAD;
+use base64::Engine;
+use hyper::{header, Body, Request, Response, StatusCode};
+use serde::de::DeserializeOwned;
+use serde::Serialize;
+use serde_json::Number;
+use tokio::sync::{mpsc, oneshot};
+use tokio::task::JoinSet;
+use tonic::transport::Server;
+use tower_http::trace::DefaultOnResponse;
+use tower_http::{compression::CompressionLayer, cors};
+use tracing::{Level, Span};
+
+use crate::auth::{Auth, Authenticated};
+use crate::connection::Connection;
+use crate::database::Database;
+use crate::error::Error;
+use crate::hrana;
+use crate::http::user::types::HttpQuery;
+use crate::namespace::{MakeNamespace, NamespaceStore};
+use crate::net::Accept;
+use crate::query::{self, Query};
+use crate::query_analysis::{predict_final_state, State, Statement};
+use crate::query_result_builder::QueryResultBuilder;
+use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer};
+use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog;
+use crate::rpc::ReplicationLogServer;
+use crate::utils::services::idle_shutdown::IdleShutdownKicker;
+use crate::version;
+
+use self::db_factory::MakeConnectionExtractor;
+use self::result_builder::JsonHttpPayloadBuilder;
+use self::types::QueryObject;
+
+impl TryFrom<query::Value> for serde_json::Value {
+    type Error = Error;
+
+    fn try_from(value: query::Value) -> Result<Self, Self::Error> {
+        let value = match value {
+            query::Value::Null => serde_json::Value::Null,
+            query::Value::Integer(i) => serde_json::Value::Number(Number::from(i)),
+            query::Value::Real(x) => {
+                serde_json::Value::Number(Number::from_f64(x).ok_or_else(|| {
+                    Error::DbValueError(format!(
+                        "Cannot convert database value `{x}` to a JSON number"
+                    ))
+                })?)
+            }
+            query::Value::Text(s) => serde_json::Value::String(s),
+            query::Value::Blob(v) => serde_json::json!({
+                "base64": BASE64_STANDARD_NO_PAD.encode(v),
+            }),
+        };
+
+        Ok(value)
+    }
+}
+
+/// Encodes query response rows into JSON
+#[derive(Debug, Serialize)]
+struct RowsResponse {
+    columns: Vec<String>,
+    rows: Vec<Vec<serde_json::Value>>,
+}
+
+fn parse_queries(queries: Vec<QueryObject>) -> crate::Result<Vec<Query>> {
+    let mut out = Vec::with_capacity(queries.len());
+    for query in queries {
+        let mut iter = Statement::parse(&query.q);
+        let stmt = iter.next().transpose()?.unwrap_or_default();
+        if iter.next().is_some() {
+            return Err(Error::FailedToParse("found more than one command in a single statement string. It is allowed to issue only one command per string.".to_string()));
+        }
+        let query = Query {
+            stmt,
+            params: query.params.0,
+            want_rows: true,
+        };
+
+        out.push(query);
+    }
+
+    match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) {
+        State::Txn => {
+            return Err(Error::QueryError(
+                "interactive transaction not allowed in HTTP queries".to_string(),
+            ))
+        }
+        State::Init => (),
+        // maybe we should err here, but let's let sqlite deal with that.
+        State::Invalid => (),
+    }
+
+    Ok(out)
+}
+
+async fn handle_query<C: Connection>(
+    auth: Authenticated,
+    MakeConnectionExtractor(connection_maker): MakeConnectionExtractor<C>,
+    Json(query): Json<HttpQuery>,
+) -> Result<axum::response::Response, Error> {
+    let batch = parse_queries(query.statements)?;
+
+    let db = connection_maker.create().await?;
+
+    let builder = JsonHttpPayloadBuilder::new();
+    let (builder, _) = db.execute_batch_or_rollback(batch, auth, builder).await?;
+
+    let res = (
+        [(header::CONTENT_TYPE, "application/json")],
+        builder.into_ret(),
+    );
+    Ok(res.into_response())
+}
+
+async fn show_console<M: MakeNamespace>(
+    AxumState(AppState { enable_console, .. }): AxumState<AppState<M>>,
+) -> impl IntoResponse {
+    if enable_console {
+        Html(std::include_str!("console.html")).into_response()
+    } else {
+        StatusCode::NOT_FOUND.into_response()
+    }
+}
+
+async fn handle_health() -> Response<Body> {
+    // return empty OK
+    Response::new(Body::empty())
+}
+
+async fn handle_upgrade<M: MakeNamespace>(
+    AxumState(AppState { upgrade_tx, .. }): AxumState<AppState<M>>,
+    req: Request<Body>,
+) -> impl IntoResponse {
+    if !hyper_tungstenite::is_upgrade_request(&req) {
+        return StatusCode::NOT_FOUND.into_response();
+    }
+
+    let (response_tx, response_rx) = oneshot::channel();
+    let _: Result<_, _> = upgrade_tx
+        .send(hrana::ws::Upgrade {
+            request: req,
+            response_tx,
+        })
+        .await;
+
+    match response_rx.await {
+        Ok(response) => response.into_response(),
+        Err(_) => (
+            StatusCode::SERVICE_UNAVAILABLE,
+            "sqld was not able to process the HTTP upgrade",
+        )
+            .into_response(),
+    }
+}
+
+async fn handle_version() -> Response<Body> {
+    let version = version::version();
+    Response::new(Body::from(version))
+}
+
+async fn handle_fallback() -> impl IntoResponse {
+    (StatusCode::NOT_FOUND).into_response()
+}
+
+/// Router-wide state that each request has access to via
+/// axum's `State` extractor.
+pub(crate) struct AppState<M: MakeNamespace> {
+    auth: Arc<Auth>,
+    namespaces: NamespaceStore<M>,
+    upgrade_tx: mpsc::Sender<hrana::ws::Upgrade>,
+    hrana_http_srv: Arc<hrana::http::Server<<M::Database as Database>::Connection>>,
+    enable_console: bool,
+    disable_default_namespace: bool,
+    disable_namespaces: bool,
+    path: Arc<Path>,
+}
+
+impl<M: MakeNamespace> Clone for AppState<M> {
+    fn clone(&self) -> Self {
+        Self {
+            auth: self.auth.clone(),
+            namespaces: self.namespaces.clone(),
+            upgrade_tx: self.upgrade_tx.clone(),
+            hrana_http_srv: self.hrana_http_srv.clone(),
+            enable_console: self.enable_console,
+            disable_default_namespace: self.disable_default_namespace,
+            disable_namespaces: self.disable_namespaces,
+            path: self.path.clone(),
+        }
+    }
+}
+
+pub struct UserApi<M: MakeNamespace, A, P, S> {
+    pub auth: Arc<Auth>,
+    pub http_acceptor: Option<A>,
+    pub hrana_ws_acceptor: Option<A>,
+    pub namespaces: NamespaceStore<M>,
+    pub idle_shutdown_kicker: Option<IdleShutdownKicker>,
+    pub proxy_service: P,
+    pub replication_service: S,
+    pub disable_default_namespace: bool,
+    pub disable_namespaces: bool,
+    pub max_response_size: u64,
+    pub enable_console: bool,
+    pub self_url: Option<String>,
+    pub path: Arc<Path>,
+}
+
+impl<M, A, P, S> UserApi<M, A, P, S>
+where
+    M: MakeNamespace,
+    A: Accept,
+    P: Proxy,
+    S: ReplicationLog,
+{
+    pub fn configure(self, join_set: &mut JoinSet<anyhow::Result<()>>) {
+        let (hrana_accept_tx, hrana_accept_rx) = mpsc::channel(8);
+        let (hrana_upgrade_tx, hrana_upgrade_rx) = mpsc::channel(8);
+        let hrana_http_srv = Arc::new(hrana::http::Server::new(self.self_url.clone()));
+
+        join_set.spawn({
+            let namespaces = self.namespaces.clone();
+            let auth = self.auth.clone();
+            let idle_kicker = self
+                .idle_shutdown_kicker
+                .clone()
+                .map(|isl| isl.into_kicker());
+            let disable_default_namespace = self.disable_default_namespace;
+            let disable_namespaces = self.disable_namespaces;
+            let max_response_size = self.max_response_size;
+            async move {
+                hrana::ws::serve(
+                    auth,
+                    idle_kicker,
+                    max_response_size,
+                    hrana_accept_rx,
+                    hrana_upgrade_rx,
+                    namespaces,
+                    disable_default_namespace,
+                    disable_namespaces,
+                )
+                .await
+                .context("Hrana server failed")
+            }
+        });
+
+        join_set.spawn({
+            let server = hrana_http_srv.clone();
+            async move {
+                server.run_expire().await;
+                Ok(())
+            }
+        });
+
+        if let Some(acceptor) = self.hrana_ws_acceptor {
+            join_set.spawn(async move {
+                hrana::ws::listen(acceptor, hrana_accept_tx).await;
+                Ok(())
+            });
+        }
+
+        if let Some(acceptor) = self.http_acceptor {
+            let state = AppState {
+                auth: self.auth,
+                upgrade_tx: hrana_upgrade_tx,
+                hrana_http_srv,
+                enable_console: self.enable_console,
+                namespaces: self.namespaces,
+                disable_default_namespace: self.disable_default_namespace,
+                disable_namespaces: self.disable_namespaces,
+                path: self.path,
+            };
+
+            fn trace_request<B>(req: &Request<B>, _span: &Span) {
+                tracing::debug!("got request: {} {}", req.method(), req.uri());
+            }
+
+            macro_rules! handle_hrana {
+                ($endpoint:expr, $version:expr, $encoding:expr,) => {{
+                    async fn handle_hrana<M: MakeNamespace>(
+                        AxumState(state): AxumState<AppState<M>>,
+                        MakeConnectionExtractor(connection_maker): MakeConnectionExtractor<
+                            <M::Database as Database>::Connection,
+                        >,
+                        auth: Authenticated,
+                        req: Request<Body>,
+                    ) -> Result<Response<Body>, Error> {
+                        Ok(state
+                            .hrana_http_srv
+                            .handle_request(
+                                connection_maker,
+                                auth,
+                                req,
+                                $endpoint,
+                                $version,
+                                $encoding,
+                            )
+                            .await?)
+                    }
+                    handle_hrana
+                }};
+            }
+
+            let app = Router::new()
+                .route("/", post(handle_query))
+                .route("/", get(handle_upgrade))
+                .route("/version", get(handle_version))
+                .route("/console", get(show_console))
+                .route("/health", get(handle_health))
+                .route("/dump", get(dump::handle_dump))
+                .route("/v1", get(hrana_over_http_1::handle_index))
+                .route("/v1/execute", post(hrana_over_http_1::handle_execute))
+                .route("/v1/batch", post(hrana_over_http_1::handle_batch))
+                .route("/v2", get(crate::hrana::http::handle_index))
+                .route(
+                    "/v2/pipeline",
+                    post(handle_hrana!(
+                        hrana::http::Endpoint::Pipeline,
+                        hrana::Version::Hrana2,
+                        hrana::Encoding::Json,
+                    )),
+                )
+                .route("/v3", get(crate::hrana::http::handle_index))
+                .route(
+                    "/v3/pipeline",
+                    post(handle_hrana!(
+                        hrana::http::Endpoint::Pipeline,
+                        hrana::Version::Hrana3,
+                        hrana::Encoding::Json,
+                    )),
+                )
+                .route(
+                    "/v3/cursor",
+                    post(handle_hrana!(
+                        hrana::http::Endpoint::Cursor,
+                        hrana::Version::Hrana3,
+                        hrana::Encoding::Json,
+                    )),
+                )
+                .route("/v3-protobuf", get(crate::hrana::http::handle_index))
+                .route(
+                    "/v3-protobuf/pipeline",
+                    post(handle_hrana!(
+                        hrana::http::Endpoint::Pipeline,
+                        hrana::Version::Hrana3,
+                        hrana::Encoding::Protobuf,
+                    )),
+                )
+                .route(
+                    "/v3-protobuf/cursor",
+                    post(handle_hrana!(
+                        hrana::http::Endpoint::Cursor,
+                        hrana::Version::Hrana3,
+                        hrana::Encoding::Protobuf,
+                    )),
+                )
+                .with_state(state);
+
+            let layered_app = app
+                .layer(option_layer(self.idle_shutdown_kicker.clone()))
+                .layer(
+                    tower_http::trace::TraceLayer::new_for_http()
+                        .on_request(trace_request)
+                        .on_response(
+                            DefaultOnResponse::new()
+                                .level(Level::DEBUG)
+                                .latency_unit(tower_http::LatencyUnit::Micros),
+                        ),
+                )
+                .layer(CompressionLayer::new())
+                .layer(
+                    cors::CorsLayer::new()
+                        .allow_methods(cors::AllowMethods::any())
+                        .allow_headers(cors::Any)
+                        .allow_origin(cors::Any),
+                );
+
+            // Merge the grpc based axum router into our regular http router
+            let replication = ReplicationLogServer::new(self.replication_service);
+            let write_proxy = ProxyServer::new(self.proxy_service);
+
+            let grpc_router = Server::builder()
+                .accept_http1(true)
+                .add_service(tonic_web::enable(replication))
+                .add_service(tonic_web::enable(write_proxy))
+                .into_router();
+
+            let router = layered_app.merge(grpc_router);
+
+            let router = router.fallback(handle_fallback);
+            let h2c = crate::h2c::H2cMaker::new(router);
+
+            join_set.spawn(async move {
+                hyper::server::Server::builder(acceptor)
+                    .serve(h2c)
+                    .await
+                    .context("http server")?;
+                Ok(())
+            });
+        }
+    }
+}
+
+/// Axum authenticated extractor
+#[tonic::async_trait]
+impl<S> FromRequestParts<S> for Authenticated
+where
+    Arc<Auth>: FromRef<S>,
+    S: Send + Sync,
+{
+    type Rejection = Error;
+
+    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+        let auth = <Arc<Auth> as FromRef<S>>::from_ref(state);
+
+        let auth_header = parts.headers.get(hyper::header::AUTHORIZATION);
+        let auth = auth.authenticate_http(auth_header)?;
+
+        Ok(auth)
+    }
+}
+
+impl<M: MakeNamespace> FromRef<AppState<M>> for Arc<Auth> {
+    fn from_ref(input: &AppState<M>) -> Self {
+        input.auth.clone()
+    }
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+#[must_use]
+pub struct Json<T>(pub T);
+
+#[tonic::async_trait]
+impl<T, S, B> FromRequest<S, B> for Json<T>
+where
+    T: DeserializeOwned,
+    B: hyper::body::HttpBody + Send + 'static,
+    B::Data: Send,
+    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+    S: Send + Sync,
+{
+    type Rejection = axum::extract::rejection::JsonRejection;
+
+    async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
+        let headers = req.headers_mut();
+
+        headers.insert(
+            header::CONTENT_TYPE,
+            HeaderValue::from_static("application/json"),
+        );
+
+        axum::Json::from_request(req, state)
+            .await
+            .map(|t| Json(t.0))
+    }
+}
diff --git a/sqld/src/http/result_builder.rs b/sqld/src/http/user/result_builder.rs
similarity index 100%
rename from sqld/src/http/result_builder.rs
rename to sqld/src/http/user/result_builder.rs
diff --git a/sqld/src/http/snapshots/sqld__http__types__test__parse_http_query.snap b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_http_query.snap
similarity index 93%
rename from sqld/src/http/snapshots/sqld__http__types__test__parse_http_query.snap
rename to sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_http_query.snap
index 378769cb..500072d2 100644
--- a/sqld/src/http/snapshots/sqld__http__types__test__parse_http_query.snap
+++ b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_http_query.snap
@@ -1,5 +1,5 @@
 ---
-source: sqld/src/http/types.rs
+source: sqld/src/http/user/types.rs
 expression: found
 ---
 {
diff --git a/sqld/src/http/snapshots/sqld__http__types__test__parse_named_params.snap b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_named_params.snap
similarity index 90%
rename from sqld/src/http/snapshots/sqld__http__types__test__parse_named_params.snap
rename to sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_named_params.snap
index 9cd03246..542c7193 100644
--- a/sqld/src/http/snapshots/sqld__http__types__test__parse_named_params.snap
+++ b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_named_params.snap
@@ -1,5 +1,5 @@
 ---
-source: sqld/src/http/types.rs
+source: sqld/src/http/user/types.rs
 expression: found
 ---
 {
diff --git a/sqld/src/http/snapshots/sqld__http__types__test__parse_positional_params.snap b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_positional_params.snap
similarity index 88%
rename from sqld/src/http/snapshots/sqld__http__types__test__parse_positional_params.snap
rename to sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_positional_params.snap
index f4816833..50b7faba 100644
--- a/sqld/src/http/snapshots/sqld__http__types__test__parse_positional_params.snap
+++ b/sqld/src/http/user/snapshots/sqld__http__user__types__test__parse_positional_params.snap
@@ -1,5 +1,5 @@
 ---
-source: sqld/src/http/types.rs
+source: sqld/src/http/user/types.rs
 expression: found
 ---
 {
diff --git a/sqld/src/http/types.rs b/sqld/src/http/user/types.rs
similarity index 100%
rename from sqld/src/http/types.rs
rename to sqld/src/http/user/types.rs
diff --git a/sqld/src/lib.rs b/sqld/src/lib.rs
index c3bd52c1..81848a44 100644
--- a/sqld/src/lib.rs
+++ b/sqld/src/lib.rs
@@ -2,8 +2,8 @@
 use std::path::{Path, PathBuf};
 use std::process::Command;
-use std::sync::mpsc::RecvTimeoutError;
-use std::sync::Arc;
+use std::str::FromStr;
+use std::sync::{Arc, Weak};
 use std::time::Duration;
 
 use anyhow::Context as AnyhowContext;
@@ -11,10 +11,8 @@ use bytes::Bytes;
 use config::{
     AdminApiConfig, DbConfig, HeartbeatConfig, RpcClientConfig, RpcServerConfig, UserApiConfig,
 };
-use futures::never::Never;
-use http::UserApi;
+use http::user::UserApi;
 use hyper::client::HttpConnector;
-use libsql::wal_hook::TRANSPARENT_METHODS;
 use namespace::{
     MakeNamespace, NamespaceStore, PrimaryNamespaceConfig, PrimaryNamespaceMaker,
     ReplicaNamespaceConfig, ReplicaNamespaceMaker,
@@ -28,13 +26,13 @@ use rpc::replication_log::rpc::replication_log_server::ReplicationLog;
 use rpc::replication_log::ReplicationLogService;
 use rpc::replication_log_proxy::ReplicationLogProxyService;
 use rpc::run_rpc_server;
-use tokio::sync::Notify;
+use tokio::sync::{mpsc, Notify};
 use tokio::task::JoinSet;
+use url::Url;
 use utils::services::idle_shutdown::IdleShutdownKicker;
 
 use crate::auth::Auth;
 use crate::connection::config::DatabaseConfigStore;
-use crate::connection::libsql::open_db;
 use crate::connection::{Connection, MakeConnection};
 use crate::error::Error;
 use crate::migration::maybe_migrate;
@@ -49,7 +47,6 @@ pub mod net;
 pub mod rpc;
 pub mod version;
 
-mod admin_api;
 mod auth;
 mod database;
 mod error;
@@ -74,6 +71,7 @@ const DEFAULT_NAMESPACE_NAME: &str = "default";
 const DEFAULT_AUTO_CHECKPOINT: u32 = 1000;
 
 type Result<T, E = Error> = std::result::Result<T, E>;
+type StatsSender = mpsc::Sender<(Bytes, Weak<Stats>)>;
 
 pub struct Server {
     pub path: Arc<Path>,
@@ -93,7 +91,6 @@
 struct Services<M: MakeNamespace, P, S> {
     namespaces: NamespaceStore<M>,
     idle_shutdown_kicker: Option<IdleShutdownKicker>,
-    stats: Stats,
     db_config_store: Arc<DatabaseConfigStore>,
     proxy_service: P,
     replication_service: S,
@@ -120,7 +117,6 @@ where
             auth: self.auth,
             namespaces: self.namespaces.clone(),
             idle_shutdown_kicker: self.idle_shutdown_kicker.clone(),
-            stats: self.stats.clone(),
             proxy_service: self.proxy_service,
             replication_service: self.replication_service,
             disable_default_namespace: self.disable_default_namespace,
@@ -134,7 +130,7 @@ where
         user_http.configure(join_set);
 
         if let Some(AdminApiConfig { acceptor }) = self.admin_api_config {
-            join_set.spawn(admin_api::run_admin_api(
+            join_set.spawn(http::admin::run(
                 acceptor,
                 self.db_config_store,
                 self.namespaces,
@@ -143,55 +139,6 @@ where
     }
 }
 
-// Periodically check the storage used by the database and save it in the Stats structure.
-// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine
-// right after checkpointing is exactly where it should be done.
-async fn run_storage_monitor(db_path: Arc<Path>, stats: Stats) -> anyhow::Result<()> {
-    let (_drop_guard, exit_notify) = std::sync::mpsc::channel::<Never>();
-    let _ = tokio::task::spawn_blocking(move || {
-        let duration = tokio::time::Duration::from_secs(60);
-        loop {
-            // because closing the last connection interferes with opening a new one, we lazily
-            // initialize a connection here, and keep it alive for the entirety of the program. If we
-            // fail to open it, we wait for `duration` and try again later.
-            let ctx = &mut ();
-            // We can safely open db with DEFAULT_AUTO_CHECKPOINT, since monitor is read-only: it
-            // won't produce new updates, frames or generate checkpoints.
-            let maybe_conn = match open_db(&db_path, &TRANSPARENT_METHODS, ctx, Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), DEFAULT_AUTO_CHECKPOINT) {
-                Ok(conn) => Some(conn),
-                Err(e) => {
-                    tracing::warn!("failed to open connection for storager monitor: {e}, trying again in {duration:?}");
-                    None
-                },
-            };
-
-            loop {
-                if let Some(ref conn) = maybe_conn {
-                    if let Ok(storage_bytes_used) =
-                        conn.query_row("select sum(pgsize) from dbstat;", [], |row| {
-                            row.get::<usize, u64>(0)
-                        })
-                    {
-                        stats.set_storage_bytes_used(storage_bytes_used);
-                    }
-                }
-
-                match exit_notify.recv_timeout(duration) {
-                    Ok(_) => unreachable!(),
-                    Err(RecvTimeoutError::Disconnected) => return,
-                    Err(RecvTimeoutError::Timeout) => (),
-
-                }
-                if maybe_conn.is_none() {
-                    break
-                }
-            }
-        }
-    }).await;
-
-    Ok(())
-}
-
 async fn run_periodic_checkpoint<C: Connection>(
     connection_maker: Arc<C>,
     period: Duration,
@@ -319,7 +266,11 @@ where
         })
     }
 
-    fn spawn_monitoring_tasks(&self, join_set: &mut JoinSet<anyhow::Result<()>>, stats: Stats) {
+    fn spawn_monitoring_tasks(
+        &self,
+        join_set: &mut JoinSet<anyhow::Result<()>>,
+        stats_receiver: mpsc::Receiver<(Bytes, Weak<Stats>)>,
+    ) -> anyhow::Result<()> {
         match self.heartbeat_config {
             Some(ref config) => {
                 tracing::info!(
@@ -328,28 +279,30 @@ where
                     config.heartbeat_period,
                 );
                 join_set.spawn({
-                    let heartbeat_url = config.heartbeat_url.clone();
                     let heartbeat_auth = config.heartbeat_auth.clone();
                     let heartbeat_period = config.heartbeat_period;
-                    let stats = stats.clone();
+                    let heartbeat_url =
+                        Url::from_str(&config.heartbeat_url).context("invalid heartbeat URL")?;
                     async move {
                         heartbeat::server_heartbeat(
                             heartbeat_url,
                             heartbeat_auth,
                             heartbeat_period,
-                            stats,
+                            stats_receiver,
                         )
                         .await;
                         Ok(())
                     }
                 });
 
-                join_set.spawn(run_storage_monitor(self.path.clone(), stats));
+                // join_set.spawn(run_storage_monitor(self.path.clone(), stats));
             }
             None => {
                 tracing::warn!("No server heartbeat configured")
            }
         }
+
+        Ok(())
     }
 
     pub async fn start(self) -> anyhow::Result<()> {
@@ -357,8 +310,8 @@ where
         init_version_file(&self.path)?;
         maybe_migrate(&self.path)?;
-        let stats = Stats::new(&self.path)?;
-        self.spawn_monitoring_tasks(&mut join_set, stats.clone());
+        let (stats_sender, stats_receiver) = mpsc::channel(8);
+        self.spawn_monitoring_tasks(&mut join_set, stats_receiver)?;
         self.init_sqlite_globals();
         let db_is_dirty = init_sentinel_file(&self.path)?;
         let idle_shutdown_kicker = self.setup_shutdown();
@@ -374,7 +327,7 @@ where
             Some(rpc_config) => {
                 let replica = Replica {
                     rpc_config,
-                    stats: stats.clone(),
+                    stats_sender,
                     db_config_store: db_config_store.clone(),
                     extensions,
                     db_config: self.db_config.clone(),
@@ -384,7 +337,6 @@ where
                 let services = Services {
                     namespaces,
                     idle_shutdown_kicker,
-                    stats,
                     db_config_store,
                     proxy_service,
                     replication_service,
@@ -404,7 +356,7 @@ where
                     rpc_config: self.rpc_server_config,
                     db_config: self.db_config.clone(),
                     idle_shutdown_kicker: idle_shutdown_kicker.clone(),
-                    stats: stats.clone(),
+                    stats_sender,
                     db_config_store: db_config_store.clone(),
                     db_is_dirty,
                     snapshot_callback,
@@ -419,7 +371,6 @@ where
                 let services = Services {
                     namespaces,
                     idle_shutdown_kicker,
-                    stats,
                     db_config_store,
                     proxy_service,
                     replication_service,
@@ -463,7 +414,7 @@ struct Primary<'a, A> {
     rpc_config: Option<RpcServerConfig<A>>,
     db_config: DbConfig,
     idle_shutdown_kicker: Option<IdleShutdownKicker>,
-    stats: Stats,
+    stats_sender: StatsSender,
     db_config_store: Arc<DatabaseConfigStore>,
     db_is_dirty: bool,
     snapshot_callback: NamespacedSnapshotCallback,
@@ -493,7 +444,7 @@ where
             snapshot_callback: self.snapshot_callback,
             bottomless_replication: self.db_config.bottomless_replication,
             extensions: self.extensions,
-            stats: self.stats,
+            stats_sender: self.stats_sender.clone(),
             config_store: self.db_config_store,
             max_response_size: self.db_config.max_response_size,
             max_total_response_size: self.db_config.max_total_response_size,
@@ -539,7 +490,7 @@ where
 
 struct Replica<C> {
     rpc_config: RpcClientConfig<C>,
-    stats: Stats,
+    stats_sender: StatsSender,
     db_config_store: Arc<DatabaseConfigStore>,
     extensions: Arc<[PathBuf]>,
     db_config: DbConfig,
@@ -560,7 +511,7 @@ impl<C> Replica<C> {
             channel: channel.clone(),
             uri: uri.clone(),
             extensions: self.extensions.clone(),
-            stats: self.stats.clone(),
+            stats_sender: self.stats_sender.clone(),
             config_store: self.db_config_store.clone(),
             base_path: self.base_path,
             max_response_size: self.db_config.max_response_size,
diff --git a/sqld/src/namespace/mod.rs b/sqld/src/namespace/mod.rs
index 15f632e6..f91b3df3 100644
--- a/sqld/src/namespace/mod.rs
+++ b/sqld/src/namespace/mod.rs
@@ -1,7 +1,7 @@
 use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 use std::time::Duration;
 
 use anyhow::{bail, Context as _};
@@ -14,6 +14,7 @@ use futures_core::future::BoxFuture;
 use futures_core::Stream;
 use hyper::Uri;
 use rusqlite::ErrorCode;
+use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS;
 use tokio::io::AsyncBufReadExt;
 use tokio::task::{block_in_place, JoinSet};
 use tokio_util::io::StreamReader;
@@ -31,8 +32,8 @@ use crate::replication::replica::Replicator;
 use crate::replication::{NamespacedSnapshotCallback, ReplicationLogger};
 use crate::stats::Stats;
 use crate::{
-    run_periodic_checkpoint, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT, DEFAULT_NAMESPACE_NAME,
-    MAX_CONCURRENT_DBS,
+    run_periodic_checkpoint, StatsSender, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT,
+    DEFAULT_NAMESPACE_NAME, MAX_CONCURRENT_DBS,
 };
 
 pub use fork::ForkError;
@@ -400,6 +401,10 @@ impl<M: MakeNamespace> NamespaceStore<M> {
         Ok(())
     }
+
+    pub async fn stats(&self, namespace: Bytes) -> crate::Result<Arc<Stats>> {
+        self.with(namespace, |ns| ns.stats.clone()).await
+    }
 }
 
 /// A namspace isolates the resources pertaining to a database of type T
@@ -408,6 +413,7 @@ pub struct Namespace<T: Database> {
     pub db: T,
     /// The set of tasks associated with this namespace
     tasks: JoinSet<anyhow::Result<()>>,
+    stats: Arc<Stats>,
 }
 
 impl<T: Database> Namespace<T> {
@@ -430,7 +436,7 @@ pub struct ReplicaNamespaceConfig {
     /// Extensions to load for the database connection
     pub extensions: Arc<[PathBuf]>,
     /// Stats monitor
-    pub stats: Stats,
+    pub stats_sender: StatsSender,
     /// Reference to the config store
     pub config_store: Arc<DatabaseConfigStore>,
 }
@@ -465,6 +471,14 @@ impl Namespace<ReplicaDatabase> {
 
         let applied_frame_no_receiver = replicator.current_frame_no_notifier.clone();
 
+        let stats = make_stats(
+            &db_path,
+            &mut join_set,
+            config.stats_sender.clone(),
+            name.clone(),
+        )
+        .await?;
+
         join_set.spawn(replicator.run());
 
         let connection_maker = MakeWriteProxyConnection::new(
@@ -472,7 +486,7 @@ impl Namespace<ReplicaDatabase> {
             config.extensions.clone(),
             config.channel.clone(),
             config.uri.clone(),
-            config.stats.clone(),
+            stats.clone(),
             config.config_store.clone(),
             applied_frame_no_receiver,
             config.max_response_size,
@@ -490,6 +504,7 @@ impl Namespace<ReplicaDatabase> {
             db: ReplicaDatabase {
                 connection_maker: Arc::new(connection_maker),
             },
+            stats,
         })
     }
 }
@@ -502,7 +517,7 @@ pub struct PrimaryNamespaceConfig {
     pub snapshot_callback: NamespacedSnapshotCallback,
     pub bottomless_replication: Option<bottomless::replicator::Options>,
     pub extensions: Arc<[PathBuf]>,
-    pub stats: Stats,
+    pub stats_sender: StatsSender,
     pub config_store: Arc<DatabaseConfigStore>,
     pub max_response_size: u64,
     pub max_total_response_size: u64,
@@ -599,11 +614,19 @@ impl Namespace<PrimaryDatabase> {
             move || ReplicationLoggerHookCtx::new(logger.clone(), bottomless_replicator.clone())
         };
 
+        let stats = make_stats(
+            &db_path,
+            &mut join_set,
+            config.stats_sender.clone(),
+            name.clone(),
+        )
+        .await?;
+
         let connection_maker: Arc<_> = LibSqlDbFactory::new(
             db_path.clone(),
             &REPLICATION_METHODS,
             ctx_builder.clone(),
-            config.stats.clone(),
+            stats.clone(),
             config.config_store.clone(),
             config.extensions.clone(),
             config.max_response_size,
@@ -646,10 +669,29 @@ impl Namespace<PrimaryDatabase> {
                 logger,
                 connection_maker,
             },
+            stats,
         })
     }
 }
 
+async fn make_stats(
+    db_path: &Path,
+    join_set: &mut JoinSet<anyhow::Result<()>>,
+    stats_sender: StatsSender,
+    name: Bytes,
+) -> anyhow::Result<Arc<Stats>> {
+    let stats = Stats::new(db_path, join_set).await?;
+
+    // the stats subscriber (the heartbeat service) is optional, so we ignore the error here.
+    let _ = stats_sender
+        .send((name.clone(), Arc::downgrade(&stats)))
+        .await;
+
+    join_set.spawn(run_storage_monitor(db_path.into(), Arc::downgrade(&stats)));
+
+    Ok(stats)
+}
+
 #[derive(Default)]
 pub enum RestoreOption {
     /// Restore database state from the most recent version found in a backup.
@@ -799,3 +841,44 @@ fn check_fresh_db(path: &Path) -> crate::Result<bool> {
     let is_fresh = !path.join("wallog").try_exists()?;
     Ok(is_fresh)
 }
+
+// Periodically check the storage used by the database and save it in the Stats structure.
+// TODO: Once we have a separate fiber that does WAL checkpoints, running this routine
+// right after checkpointing is exactly where it should be done.
+async fn run_storage_monitor(db_path: PathBuf, stats: Weak<Stats>) -> anyhow::Result<()> {
+    // on initialization, the database file doesn't exist yet, so we wait a bit for it to be
+    // created
+    tokio::time::sleep(Duration::from_secs(1)).await;
+
+    let duration = tokio::time::Duration::from_secs(60);
+    let db_path: Arc<Path> = db_path.into();
+    loop {
+        let db_path = db_path.clone();
+        let Some(stats) = stats.upgrade() else { return Ok(()) };
+        let _ = tokio::task::spawn_blocking(move || {
+            // because closing the last connection interferes with opening a new one, we
+            // initialize the connection lazily inside each iteration of the loop. If we
+            // fail to open it, we wait for `duration` and try again later.
+            let ctx = &mut ();
+            // We can safely open db with DEFAULT_AUTO_CHECKPOINT, since monitor is read-only: it
+            // won't produce new updates, frames or generate checkpoints.
+            match open_db(&db_path, &TRANSPARENT_METHODS, ctx, Some(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY), DEFAULT_AUTO_CHECKPOINT) {
+                Ok(conn) => {
+                    if let Ok(storage_bytes_used) =
+                        conn.query_row("select sum(pgsize) from dbstat;", [], |row| {
+                            row.get::<usize, u64>(0)
+                        })
+                    {
+                        stats.set_storage_bytes_used(storage_bytes_used);
+                    }
+                },
+                Err(e) => {
+                    tracing::warn!("failed to open connection for storage monitor: {e}, trying again in {duration:?}");
+                },
+            }
+        }).await;
+
+        tokio::time::sleep(duration).await;
+    }
+}
diff --git a/sqld/src/stats.rs b/sqld/src/stats.rs
index 8f2138c1..0e966f75 100644
--- a/sqld/src/stats.rs
+++ b/sqld/src/stats.rs
@@ -1,19 +1,14 @@
-use std::fs::{File, OpenOptions};
-use std::io::Seek;
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 use std::time::Duration;
 
 use serde::{Deserialize, Serialize};
+use tokio::io::AsyncWriteExt;
+use tokio::task::JoinSet;
 
-#[derive(Clone, Default)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct Stats {
-    inner: Arc<StatsInner>,
-}
-
-#[derive(Serialize, Deserialize, Default)]
-struct StatsInner {
     rows_written: AtomicU64,
     rows_read: AtomicU64,
     storage_bytes_used: AtomicU64,
@@ -22,70 +17,85 @@ struct StatsInner {
 }
 
 impl Stats {
-    pub fn new(db_path: &Path) -> anyhow::Result<Self> {
+    pub async fn new(
+        db_path: &Path,
+        join_set: &mut JoinSet<anyhow::Result<()>>,
+    ) -> anyhow::Result<Arc<Self>> {
         let stats_path = db_path.join("stats.json");
-        let stats_file = OpenOptions::new()
-            .read(true)
-            .write(true)
-            .create(true)
-            .open(stats_path)?;
-
-        let stats_inner =
-            serde_json::from_reader(&stats_file).unwrap_or_else(|_| StatsInner::default());
-        let inner = Arc::new(stats_inner);
-
-        spawn_stats_persist_thread(inner.clone(), stats_file);
-
-        Ok(Self { inner })
+        let this = if stats_path.try_exists()? {
+            let data = tokio::fs::read_to_string(&stats_path).await?;
+            Arc::new(serde_json::from_str(&data)?)
+        } else {
+            Arc::new(Stats::default())
+        };
+
+        join_set.spawn(spawn_stats_persist_thread(
+            Arc::downgrade(&this),
+            stats_path.to_path_buf(),
+        ));
+
+        Ok(this)
     }
 
     /// increments the number of written rows by n
     pub fn inc_rows_written(&self, n: u64) {
-        self.inner.rows_written.fetch_add(n, Ordering::Relaxed);
+        self.rows_written.fetch_add(n, Ordering::Relaxed);
     }
 
     /// increments the number of read rows by n
     pub fn inc_rows_read(&self, n: u64) {
-        self.inner.rows_read.fetch_add(n, Ordering::Relaxed);
+        self.rows_read.fetch_add(n, Ordering::Relaxed);
     }
 
     pub fn set_storage_bytes_used(&self, n: u64) {
-        self.inner.storage_bytes_used.store(n, Ordering::Relaxed);
+        self.storage_bytes_used.store(n, Ordering::Relaxed);
    }
 
     /// returns the total number of rows read since this database was created
     pub fn rows_read(&self) -> u64 {
-        self.inner.rows_read.load(Ordering::Relaxed)
+        self.rows_read.load(Ordering::Relaxed)
     }
 
     /// returns the total number of rows written since this database was created
     pub fn rows_written(&self) -> u64 {
-        self.inner.rows_written.load(Ordering::Relaxed)
+        self.rows_written.load(Ordering::Relaxed)
     }
 
     /// returns the total number of bytes used by the database (excluding uncheckpointed WAL entries)
     pub fn storage_bytes_used(&self) -> u64 {
-        self.inner.storage_bytes_used.load(Ordering::Relaxed)
+        self.storage_bytes_used.load(Ordering::Relaxed)
     }
 
     /// increments the number of the write requests which were delegated from a replica to primary
     pub fn inc_write_requests_delegated(&self) {
-        self.inner
-            .write_requests_delegated
+        self.write_requests_delegated
             .fetch_add(1, Ordering::Relaxed);
     }
 
     pub fn write_requests_delegated(&self) -> u64 {
-        self.inner.write_requests_delegated.load(Ordering::Relaxed)
+        self.write_requests_delegated.load(Ordering::Relaxed)
     }
 }
 
-fn spawn_stats_persist_thread(stats: Arc<StatsInner>, mut file: File) {
-    std::thread::spawn(move || loop {
-        if file.rewind().is_ok() {
-            file.set_len(0).unwrap();
-            let _ = serde_json::to_writer(&mut file, &stats);
+async fn spawn_stats_persist_thread(stats: Weak<Stats>, path: PathBuf) -> anyhow::Result<()> {
+    loop {
+        if let Err(e) = try_persist_stats(stats.clone(), &path).await {
+            tracing::error!("error persisting stats file: {e}");
         }
-        std::thread::sleep(Duration::from_secs(5));
-    });
+        tokio::time::sleep(Duration::from_secs(5)).await;
+    }
+}
+
+async fn try_persist_stats(stats: Weak<Stats>, path: &Path) -> anyhow::Result<()> {
+    let temp_path = path.with_extension("tmp");
+    let mut file = tokio::fs::OpenOptions::new()
+        .write(true)
+        .create(true)
+        .open(&temp_path)
+        .await?;
+    file.set_len(0).await?;
+    file.write_all(&serde_json::to_vec(&stats)?).await?;
+    file.flush().await?;
+    tokio::fs::rename(temp_path, path).await?;
+    Ok(())
+}
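
Note (not part of the patch): below is a minimal, self-contained sketch of the Weak-subscription pattern this diff introduces in heartbeat.rs and make_stats. Each namespace hands the heartbeat task a Weak<Stats> over a channel; on every tick the task upgrades the weak pointers it holds and prunes subscriptions whose namespace has been dropped. All names in the sketch are illustrative, not the sqld API; it assumes a tokio runtime.

// Sketch: weak-subscription stats registry (illustrative names, assumed tokio runtime).
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;

use tokio::sync::mpsc;

#[derive(Default)]
struct Stats {
    rows_read: AtomicU64,
}

// Mirrors server_heartbeat: accept (namespace, Weak<Stats>) subscriptions and
// report on a fixed interval, skipping missed ticks.
async fn heartbeat(mut subs: mpsc::Receiver<(String, Weak<Stats>)>) {
    let mut watched: HashMap<String, Weak<Stats>> = HashMap::new();
    let mut interval = tokio::time::interval(Duration::from_millis(100));
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    loop {
        tokio::select! {
            Some((ns, stats)) = subs.recv() => {
                watched.insert(ns, stats);
            }
            _ = interval.tick() => {
                for (ns, stats) in watched.iter() {
                    // upgrade() fails once the namespace dropped its Arc<Stats>,
                    // so a dead namespace is silently skipped...
                    if let Some(stats) = stats.upgrade() {
                        println!("{ns}: rows_read={}", stats.rows_read.load(Ordering::Relaxed));
                    }
                }
                // ...and its subscription is pruned on the same tick.
                watched.retain(|_, s| s.upgrade().is_some());
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    tokio::spawn(heartbeat(rx));

    // A "namespace" registers its stats, then later goes away.
    let stats = Arc::new(Stats::default());
    tx.send(("default".into(), Arc::downgrade(&stats))).await.unwrap();
    stats.rows_read.fetch_add(42, Ordering::Relaxed);
    tokio::time::sleep(Duration::from_millis(250)).await;

    drop(stats); // the heartbeat stops reporting "default" and prunes it
    tokio::time::sleep(Duration::from_millis(250)).await;
}

A related design choice worth noting: try_persist_stats in the patch writes the snapshot to a .tmp file and renames it over stats.json, so the on-disk stats stay consistent even if the process dies mid-write.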