Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

write proxy: expire old write delegation streams #723

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqld/src/connection/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct LibSqlConnection {
sender: crossbeam::channel::Sender<ExecCallback>,
}
Expand Down
29 changes: 27 additions & 2 deletions sqld/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::time::Duration;

Expand Down Expand Up @@ -231,6 +231,14 @@ impl Drop for WaitersGuard<'_> {
}
}

fn now_millis() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64
}

#[async_trait::async_trait]
impl<F: MakeConnection> MakeConnection for MakeThrottledConnection<F> {
type Connection = TrackedConnection<F::Connection>;
Expand Down Expand Up @@ -266,14 +274,28 @@ impl<F: MakeConnection> MakeConnection for MakeThrottledConnection<F> {
}

let inner = self.connection_maker.create().await?;
Ok(TrackedConnection { permit, inner })
Ok(TrackedConnection {
permit,
inner,
atime: AtomicU64::new(now_millis()),
})
}
}

#[derive(Debug)]
pub struct TrackedConnection<DB> {
inner: DB,
#[allow(dead_code)] // just hold on to it
permit: tokio::sync::OwnedSemaphorePermit,
atime: AtomicU64,
}

impl<DB> TrackedConnection<DB> {
pub fn idle_time(&self) -> Duration {
let now = now_millis();
let atime = self.atime.load(Ordering::Relaxed);
Duration::from_millis(now.saturating_sub(atime))
}
}

#[async_trait::async_trait]
Expand All @@ -286,6 +308,7 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {
builder: B,
replication_index: Option<FrameNo>,
) -> crate::Result<(B, State)> {
self.atime.store(now_millis(), Ordering::Relaxed);
self.inner
.execute_program(pgm, auth, builder, replication_index)
.await
Expand All @@ -298,6 +321,7 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {
auth: Authenticated,
replication_index: Option<FrameNo>,
) -> crate::Result<DescribeResult> {
self.atime.store(now_millis(), Ordering::Relaxed);
self.inner.describe(sql, auth, replication_index).await
}

Expand All @@ -308,6 +332,7 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {

#[inline]
async fn checkpoint(&self) -> Result<()> {
self.atime.store(now_millis(), Ordering::Relaxed);
self.inner.checkpoint().await
}
}
Expand Down
24 changes: 23 additions & 1 deletion sqld/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,20 @@ where
}

if let Some(config) = self.rpc_config.take() {
let proxy_service =
ProxyService::new(namespaces.clone(), None, self.disable_namespaces);
// Garbage collect proxy clients every 30 seconds
self.join_set.spawn({
let clients = proxy_service.clients();
async move {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
rpc::proxy::garbage_collect(&mut *clients.write().await).await;
}
}
});
self.join_set.spawn(run_rpc_server(
proxy_service,
config.acceptor,
config.tls_config,
self.idle_shutdown_kicker.clone(),
Expand All @@ -498,7 +511,16 @@ where

let proxy_service =
ProxyService::new(namespaces.clone(), Some(self.auth), self.disable_namespaces);

// Garbage collect proxy clients every 30 seconds
self.join_set.spawn({
let clients = proxy_service.clients();
async move {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
rpc::proxy::garbage_collect(&mut *clients.write().await).await;
}
}
});
Ok((namespaces, proxy_service, logger_service))
}
}
Expand Down
3 changes: 1 addition & 2 deletions sqld/src/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ pub mod replication_log_proxy;
pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST";
pub(crate) const NAMESPACE_METADATA_KEY: &str = "x-namespace-bin";

#[allow(clippy::too_many_arguments)]
pub async fn run_rpc_server<A: crate::net::Accept>(
proxy_service: ProxyService,
acceptor: A,
maybe_tls: Option<TlsConfig>,
idle_shutdown_layer: Option<IdleShutdownKicker>,
namespaces: NamespaceStore<PrimaryNamespaceMaker>,
disable_namespaces: bool,
) -> anyhow::Result<()> {
let proxy_service = ProxyService::new(namespaces.clone(), None, disable_namespaces);
let logger_service = ReplicationLogService::new(
namespaces.clone(),
idle_shutdown_layer.clone(),
Expand Down
23 changes: 21 additions & 2 deletions sqld/src/rpc/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ pub mod rpc {
}

pub struct ProxyService {
clients: RwLock<HashMap<Uuid, Arc<TrackedConnection<LibSqlConnection>>>>,
clients: Arc<RwLock<HashMap<Uuid, Arc<TrackedConnection<LibSqlConnection>>>>>,
namespaces: NamespaceStore<PrimaryNamespaceMaker>,
auth: Option<Arc<Auth>>,
disable_namespaces: bool,
Expand All @@ -277,13 +277,19 @@ impl ProxyService {
auth: Option<Arc<Auth>>,
disable_namespaces: bool,
) -> Self {
let clients: Arc<RwLock<HashMap<Uuid, Arc<TrackedConnection<LibSqlConnection>>>>> =
Default::default();
Self {
clients: Default::default(),
clients,
namespaces,
auth,
disable_namespaces,
}
}

pub fn clients(&self) -> Arc<RwLock<HashMap<Uuid, Arc<TrackedConnection<LibSqlConnection>>>>> {
self.clients.clone()
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -441,6 +447,19 @@ impl QueryResultBuilder for ExecuteResultBuilder {
}
}

// Disconnects all clients that have been idle for more than 30 seconds.
// FIXME: we should also keep a list of recently disconnected clients,
// and if one should arrive with a late message, it should be rejected
// with an error. A similar mechanism is already implemented in hrana-over-http.
pub async fn garbage_collect(
clients: &mut HashMap<Uuid, Arc<TrackedConnection<LibSqlConnection>>>,
) {
let limit = std::time::Duration::from_secs(30);

clients.retain(|_, db| db.idle_time() < limit);
tracing::trace!("gc: remaining client handles: {:?}", clients);
}

#[tonic::async_trait]
impl Proxy for ProxyService {
async fn execute(
Expand Down