Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit 93de94d

Browse files
authored
Abstract sqld networking (#650)
* refactor for generic networking * review edits
1 parent ad8e899 commit 93de94d

File tree

22 files changed

+1448
-1029
lines changed

22 files changed

+1448
-1029
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqld/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mimalloc = { version = "0.1.36", default-features = false }
3434
nix = { version = "0.26.2", features = ["fs"] }
3535
once_cell = "1.17.0"
3636
parking_lot = "0.12.1"
37+
pin-project-lite = "0.2.13"
3738
priority-queue = "1.3"
3839
prost = "0.11.3"
3940
rand = "0.8"

sqld/src/admin_api.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use axum::Json;
55
use chrono::NaiveDateTime;
66
use futures::TryStreamExt;
77
use serde::Deserialize;
8+
use std::io::ErrorKind;
89
use std::sync::Arc;
9-
use std::{io::ErrorKind, net::SocketAddr};
1010
use tokio_util::io::ReaderStream;
1111
use url::Url;
1212
use uuid::Uuid;
@@ -17,14 +17,18 @@ use crate::namespace::{DumpStream, MakeNamespace, NamespaceStore, RestoreOption}
1717

1818
struct AppState<M: MakeNamespace> {
1919
db_config_store: Arc<DatabaseConfigStore>,
20-
namespaces: Arc<NamespaceStore<M>>,
20+
namespaces: NamespaceStore<M>,
2121
}
2222

23-
pub async fn run_admin_api<M: MakeNamespace>(
24-
addr: SocketAddr,
23+
pub async fn run_admin_api<M, A>(
24+
acceptor: A,
2525
db_config_store: Arc<DatabaseConfigStore>,
26-
namespaces: Arc<NamespaceStore<M>>,
27-
) -> anyhow::Result<()> {
26+
namespaces: NamespaceStore<M>,
27+
) -> anyhow::Result<()>
28+
where
29+
A: crate::net::Accept,
30+
M: MakeNamespace,
31+
{
2832
use axum::routing::{get, post};
2933
let router = axum::Router::new()
3034
.route("/", get(handle_get_index))
@@ -48,15 +52,10 @@ pub async fn run_admin_api<M: MakeNamespace>(
4852
namespaces,
4953
}));
5054

51-
let server = hyper::Server::try_bind(&addr)
52-
.context("Could not bind admin HTTP API server")?
53-
.serve(router.into_make_service());
54-
55-
tracing::info!(
56-
"Listening for admin HTTP API requests on {}",
57-
server.local_addr()
58-
);
59-
server.await?;
55+
hyper::server::Server::builder(acceptor)
56+
.serve(router.into_make_service())
57+
.await
58+
.context("Could not bind admin HTTP API server")?;
6059
Ok(())
6160
}
6261

sqld/src/config.rs

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
use std::net::SocketAddr;
2+
use std::path::{Path, PathBuf};
3+
use std::sync::Arc;
4+
use std::time::Duration;
5+
6+
use anyhow::Context;
7+
use hyper::client::HttpConnector;
8+
use sha256::try_digest;
9+
use tonic::transport::Channel;
10+
11+
use crate::auth::{self, Auth};
12+
use crate::net::{AddrIncoming, Connector};
13+
14+
pub struct RpcClientConfig<C = HttpConnector> {
15+
pub remote_url: String,
16+
pub connector: C,
17+
pub tls_config: Option<TlsConfig>,
18+
}
19+
20+
impl<C: Connector> RpcClientConfig<C> {
21+
pub(crate) async fn configure(self) -> anyhow::Result<(Channel, tonic::transport::Uri)> {
22+
let uri = tonic::transport::Uri::from_maybe_shared(self.remote_url)?;
23+
let mut builder = Channel::builder(uri.clone());
24+
if let Some(ref tls_config) = self.tls_config {
25+
let cert_pem = std::fs::read_to_string(&tls_config.cert)?;
26+
let key_pem = std::fs::read_to_string(&tls_config.key)?;
27+
let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
28+
29+
let ca_cert_pem = std::fs::read_to_string(&tls_config.ca_cert)?;
30+
let ca_cert = tonic::transport::Certificate::from_pem(ca_cert_pem);
31+
32+
let tls_config = tonic::transport::ClientTlsConfig::new()
33+
.identity(identity)
34+
.ca_certificate(ca_cert)
35+
.domain_name("sqld");
36+
builder = builder.tls_config(tls_config)?;
37+
}
38+
39+
let channel = builder.connect_with_connector_lazy(self.connector);
40+
41+
Ok((channel, uri))
42+
}
43+
}
44+
45+
#[derive(Clone)]
46+
pub struct TlsConfig {
47+
pub cert: PathBuf,
48+
pub key: PathBuf,
49+
pub ca_cert: PathBuf,
50+
}
51+
52+
pub struct RpcServerConfig<A = AddrIncoming> {
53+
pub acceptor: A,
54+
pub addr: SocketAddr,
55+
pub tls_config: Option<TlsConfig>,
56+
}
57+
58+
pub struct UserApiConfig<A = AddrIncoming> {
59+
pub hrana_ws_acceptor: Option<A>,
60+
pub http_acceptor: Option<A>,
61+
pub enable_http_console: bool,
62+
pub self_url: Option<String>,
63+
pub http_auth: Option<String>,
64+
pub auth_jwt_key: Option<String>,
65+
}
66+
67+
impl<A> UserApiConfig<A> {
68+
pub fn get_auth(&self) -> anyhow::Result<Auth> {
69+
let mut auth = Auth::default();
70+
71+
if let Some(arg) = self.http_auth.as_deref() {
72+
if let Some(param) = auth::parse_http_basic_auth_arg(arg)? {
73+
auth.http_basic = Some(param);
74+
tracing::info!("Using legacy HTTP basic authentication");
75+
}
76+
}
77+
78+
if let Some(jwt_key) = self.auth_jwt_key.as_deref() {
79+
let jwt_key =
80+
auth::parse_jwt_key(jwt_key).context("Could not parse JWT decoding key")?;
81+
auth.jwt_key = Some(jwt_key);
82+
tracing::info!("Using JWT-based authentication");
83+
}
84+
85+
auth.disabled = auth.http_basic.is_none() && auth.jwt_key.is_none();
86+
if auth.disabled {
87+
tracing::warn!(
88+
"No authentication specified, the server will not require authentication"
89+
)
90+
}
91+
92+
Ok(auth)
93+
}
94+
}
95+
96+
pub struct AdminApiConfig<A = AddrIncoming> {
97+
pub acceptor: A,
98+
}
99+
100+
#[derive(Clone)]
101+
pub struct DbConfig {
102+
pub extensions_path: Option<Arc<Path>>,
103+
pub bottomless_replication: Option<bottomless::replicator::Options>,
104+
pub max_log_size: u64,
105+
pub max_log_duration: Option<f32>,
106+
pub soft_heap_limit_mb: Option<usize>,
107+
pub hard_heap_limit_mb: Option<usize>,
108+
pub max_response_size: u64,
109+
pub max_total_response_size: u64,
110+
pub snapshot_exec: Option<String>,
111+
pub checkpoint_interval: Option<Duration>,
112+
}
113+
114+
impl DbConfig {
115+
pub fn validate_extensions(&self) -> anyhow::Result<Arc<[PathBuf]>> {
116+
let mut valid_extensions = vec![];
117+
if let Some(ext_dir) = &self.extensions_path {
118+
let extensions_list = ext_dir.join("trusted.lst");
119+
120+
let file_contents = std::fs::read_to_string(&extensions_list)
121+
.with_context(|| format!("can't read {}", &extensions_list.display()))?;
122+
123+
let extensions = file_contents.lines().filter(|c| !c.is_empty());
124+
125+
for line in extensions {
126+
let mut ext_info = line.trim().split_ascii_whitespace();
127+
128+
let ext_sha = ext_info.next().ok_or_else(|| {
129+
anyhow::anyhow!("invalid line on {}: {}", &extensions_list.display(), line)
130+
})?;
131+
let ext_fname = ext_info.next().ok_or_else(|| {
132+
anyhow::anyhow!("invalid line on {}: {}", &extensions_list.display(), line)
133+
})?;
134+
135+
anyhow::ensure!(
136+
ext_info.next().is_none(),
137+
"extension list seem to contain a filename with whitespaces. Rejected"
138+
);
139+
140+
let extension_full_path = ext_dir.join(ext_fname);
141+
let digest = try_digest(extension_full_path.as_path()).with_context(|| {
142+
format!(
143+
"Failed to get sha256 digest, while trying to read {}",
144+
extension_full_path.display()
145+
)
146+
})?;
147+
148+
anyhow::ensure!(
149+
digest == ext_sha,
150+
"sha256 differs for {}. Got {}",
151+
ext_fname,
152+
digest
153+
);
154+
valid_extensions.push(extension_full_path);
155+
}
156+
}
157+
158+
Ok(valid_extensions.into())
159+
}
160+
}
161+
162+
pub struct HeartbeatConfig {
163+
pub heartbeat_url: String,
164+
pub heartbeat_period: Duration,
165+
pub heartbeat_auth: Option<String>,
166+
}

sqld/src/connection/libsql.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub struct LibSqlDbFactory<W: WalHook + 'static> {
3030
ctx_builder: Box<dyn Fn() -> W::Context + Sync + Send + 'static>,
3131
stats: Stats,
3232
config_store: Arc<DatabaseConfigStore>,
33-
extensions: Vec<PathBuf>,
33+
extensions: Arc<[PathBuf]>,
3434
max_response_size: u64,
3535
max_total_response_size: u64,
3636
auto_checkpoint: u32,
@@ -51,7 +51,7 @@ where
5151
ctx_builder: F,
5252
stats: Stats,
5353
config_store: Arc<DatabaseConfigStore>,
54-
extensions: Vec<PathBuf>,
54+
extensions: Arc<[PathBuf]>,
5555
max_response_size: u64,
5656
max_total_response_size: u64,
5757
auto_checkpoint: u32,
@@ -165,7 +165,7 @@ where
165165
impl LibSqlConnection {
166166
pub async fn new<W>(
167167
path: impl AsRef<Path> + Send + 'static,
168-
extensions: Vec<PathBuf>,
168+
extensions: Arc<[PathBuf]>,
169169
wal_hook: &'static WalMethodsHook<W>,
170170
hook_ctx: W::Context,
171171
stats: Stats,
@@ -250,7 +250,7 @@ struct Connection<'a> {
250250
impl<'a> Connection<'a> {
251251
fn new<W: WalHook>(
252252
path: &Path,
253-
extensions: Vec<PathBuf>,
253+
extensions: Arc<[PathBuf]>,
254254
wal_methods: &'static WalMethodsHook<W>,
255255
hook_ctx: &'a mut W::Context,
256256
stats: Stats,
@@ -272,10 +272,10 @@ impl<'a> Connection<'a> {
272272
builder_config,
273273
};
274274

275-
for ext in extensions {
275+
for ext in extensions.iter() {
276276
unsafe {
277277
let _guard = rusqlite::LoadExtensionGuard::new(&this.conn).unwrap();
278-
if let Err(e) = this.conn.load_extension(&ext, None) {
278+
if let Err(e) = this.conn.load_extension(ext, None) {
279279
tracing::error!("failed to load extension: {}", ext.display());
280280
Err(e)?;
281281
}

sqld/src/connection/write_proxy.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use super::{MakeConnection, Program};
3636
pub struct MakeWriteProxyConnection {
3737
client: ProxyClient<Channel>,
3838
db_path: PathBuf,
39-
extensions: Vec<PathBuf>,
39+
extensions: Arc<[PathBuf]>,
4040
stats: Stats,
4141
config_store: Arc<DatabaseConfigStore>,
4242
applied_frame_no_receiver: watch::Receiver<FrameNo>,
@@ -49,7 +49,7 @@ impl MakeWriteProxyConnection {
4949
#[allow(clippy::too_many_arguments)]
5050
pub fn new(
5151
db_path: PathBuf,
52-
extensions: Vec<PathBuf>,
52+
extensions: Arc<[PathBuf]>,
5353
channel: Channel,
5454
uri: tonic::transport::Uri,
5555
stats: Stats,
@@ -165,7 +165,7 @@ impl WriteProxyConnection {
165165
async fn new(
166166
write_proxy: ProxyClient<Channel>,
167167
db_path: PathBuf,
168-
extensions: Vec<PathBuf>,
168+
extensions: Arc<[PathBuf]>,
169169
stats: Stats,
170170
config_store: Arc<DatabaseConfigStore>,
171171
applied_frame_no_receiver: watch::Receiver<FrameNo>,

sqld/src/http/h2c.rs renamed to sqld/src/h2c.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ use std::pin::Pin;
4242

4343
use axum::{body::BoxBody, http::HeaderValue};
4444
use hyper::header;
45-
use hyper::server::conn::AddrStream;
4645
use hyper::Body;
4746
use hyper::{Request, Response};
4847
use tonic::transport::server::TcpConnectInfo;
@@ -63,12 +62,13 @@ impl<S> H2cMaker<S> {
6362
}
6463
}
6564

66-
impl<S> Service<&AddrStream> for H2cMaker<S>
65+
impl<S, C> Service<&C> for H2cMaker<S>
6766
where
6867
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
6968
S::Future: Send + 'static,
7069
S::Error: Into<BoxError> + Sync + Send + 'static,
7170
S::Response: Send + 'static,
71+
C: crate::net::Conn,
7272
{
7373
type Response = H2c<S>;
7474

@@ -84,11 +84,8 @@ where
8484
std::task::Poll::Ready(Ok(()))
8585
}
8686

87-
fn call(&mut self, conn: &AddrStream) -> Self::Future {
88-
let connect_info = TcpConnectInfo {
89-
local_addr: Some(conn.local_addr()),
90-
remote_addr: Some(conn.remote_addr()),
91-
};
87+
fn call(&mut self, conn: &C) -> Self::Future {
88+
let connect_info = conn.connect_info();
9289
let s = self.s.clone();
9390
Box::pin(async move { Ok(H2c { s, connect_info }) })
9491
}

sqld/src/hrana/ws/conn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct ResponseFuture {
5050

5151
pub(super) async fn handle_tcp<F: MakeNamespace>(
5252
server: Arc<Server<F>>,
53-
socket: tokio::net::TcpStream,
53+
socket: Box<dyn crate::net::Conn>,
5454
conn_id: u64,
5555
) -> Result<()> {
5656
let handshake::Output {

sqld/src/hrana/ws/handshake.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ use tokio_tungstenite::tungstenite;
55
use tungstenite::http;
66

77
use crate::http::db_factory::namespace_from_headers;
8+
use crate::net::Conn;
89

910
use super::super::{Encoding, Version};
1011
use super::Upgrade;
1112

12-
#[derive(Debug)]
1313
pub enum WebSocket {
14-
Tcp(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>),
14+
Tcp(tokio_tungstenite::WebSocketStream<Box<dyn Conn>>),
1515
Upgraded(tokio_tungstenite::WebSocketStream<hyper::upgrade::Upgraded>),
1616
}
1717

@@ -23,7 +23,6 @@ enum Subproto {
2323
Hrana3Protobuf,
2424
}
2525

26-
#[derive(Debug)]
2726
pub struct Output {
2827
pub ws: WebSocket,
2928
pub version: Version,
@@ -32,14 +31,10 @@ pub struct Output {
3231
}
3332

3433
pub async fn handshake_tcp(
35-
socket: tokio::net::TcpStream,
34+
socket: Box<dyn Conn>,
3635
disable_default_ns: bool,
3736
disable_namespaces: bool,
3837
) -> Result<Output> {
39-
socket
40-
.set_nodelay(true)
41-
.context("Could not disable Nagle's algorithm")?;
42-
4338
let mut subproto = None;
4439
let mut namespace = None;
4540
let callback = |req: &http::Request<()>, resp: http::Response<()>| {

0 commit comments

Comments
 (0)