Skip to content

Commit

Permalink
Add support for namespace to the remote connection builder
Browse files Browse the repository at this point in the history
  • Loading branch information
jameswritescode committed Dec 4, 2024
1 parent 9241b00 commit 91611e7
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 10 deletions.
7 changes: 7 additions & 0 deletions libsql-server/src/http/user/db_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ pub fn namespace_from_headers(

if let Some(from_metadata) = headers.get(NAMESPACE_METADATA_KEY) {
try_namespace_from_metadata(from_metadata)
} else if let Some(from_ns_header) = headers.get("x-namespace") {
try_namespace_from_header(from_ns_header)
} else if let Some(from_host) = headers.get("host") {
try_namespace_from_host(from_host, disable_default_namespace)
} else if !disable_default_namespace {
Expand All @@ -59,6 +61,11 @@ pub fn namespace_from_headers(
}
}

fn try_namespace_from_header(header: &axum::http::HeaderValue) -> Result<NamespaceName, Error> {
NamespaceName::from_bytes(header.as_bytes().to_vec().into())
.map_err(|_| Error::InvalidNamespace)
}

fn try_namespace_from_host(
from_host: &axum::http::HeaderValue,
disable_default_namespace: bool,
Expand Down
40 changes: 40 additions & 0 deletions libsql-server/tests/embedded_replica/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,3 +1696,43 @@ fn schema_db() {

sim.run().unwrap();
}

#[test]
fn remote_namespace_header_support() {
let tmp_host = tempdir().unwrap();
let tmp_host_path = tmp_host.path().to_owned();

let mut sim = Builder::new()
.simulation_duration(Duration::from_secs(1000))
.build();

make_primary(&mut sim, tmp_host_path.clone());

sim.client("client", async move {
let client = Client::new();

client
.post("http://primary:9090/v1/namespaces/foo/create", json!({}))
.await?;

let db_url = "http://primary:8080";

let remote = libsql::Builder::new_remote(db_url.to_string(), String::new())
.namespace("foo")
.connector(TurmoilConnector)
.build()
.await
.unwrap();

let conn = remote.connect().unwrap();

conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ())
.await?;

conn.execute("INSERT into user(id) values (1);", ()).await?;

Ok(())
});

sim.run().unwrap();
}
5 changes: 5 additions & 0 deletions libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ enum DbType {
auth_token: String,
connector: crate::util::ConnectorService,
version: Option<String>,
namespace: Option<String>,
},
}

Expand Down Expand Up @@ -230,6 +231,7 @@ cfg_replication! {
OpenFlags::default(),
encryption_config.clone(),
None,
None,
).await?;

Ok(Database {
Expand Down Expand Up @@ -514,6 +516,7 @@ cfg_remote! {
auth_token: auth_token.into(),
connector: crate::util::ConnectorService::new(svc),
version,
namespace: None,
},
max_write_replication_index: Default::default(),
})
Expand Down Expand Up @@ -650,13 +653,15 @@ impl Database {
auth_token,
connector,
version,
namespace,
} => {
let conn = std::sync::Arc::new(
crate::hrana::connection::HttpConnection::new_with_connector(
url,
auth_token,
connector.clone(),
version.as_ref().map(|s| s.as_str()),
namespace.as_ref().map(|s| s.as_str()),
),
);

Expand Down
22 changes: 18 additions & 4 deletions libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ impl Builder<()> {
auth_token,
connector: None,
version: None,
namespace: None,
},
encryption_config: None,
read_your_writes: true,
sync_interval: None,
http_request_callback: None,
namespace: None
},
}
}
Expand Down Expand Up @@ -99,6 +99,7 @@ impl Builder<()> {
auth_token,
connector: None,
version: None,
namespace: None,
},
connector:None,
},
Expand All @@ -115,6 +116,7 @@ impl Builder<()> {
auth_token,
connector: None,
version: None,
namespace: None,
},
}
}
Expand All @@ -128,6 +130,7 @@ cfg_replication_or_remote_or_sync! {
auth_token: String,
connector: Option<crate::util::ConnectorService>,
version: Option<String>,
namespace: Option<String>,
}
}

Expand Down Expand Up @@ -195,7 +198,6 @@ cfg_replication! {
read_your_writes: bool,
sync_interval: Option<std::time::Duration>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
namespace: Option<String>,
}

/// Local replica configuration type in [`Builder`].
Expand Down Expand Up @@ -260,7 +262,7 @@ cfg_replication! {
/// Set the namespace that will be communicated to remote replica in the http header.
pub fn namespace(mut self, namespace: impl Into<String>) -> Builder<RemoteReplica>
{
self.inner.namespace = Some(namespace.into());
self.inner.remote.namespace = Some(namespace.into());
self
}

Expand All @@ -280,12 +282,12 @@ cfg_replication! {
auth_token,
connector,
version,
namespace,
},
encryption_config,
read_your_writes,
sync_interval,
http_request_callback,
namespace
} = self.inner;

let connector = if let Some(connector) = connector {
Expand Down Expand Up @@ -357,6 +359,7 @@ cfg_replication! {
auth_token,
connector,
version,
namespace,
}) = remote
{
let connector = if let Some(connector) = connector {
Expand All @@ -381,6 +384,7 @@ cfg_replication! {
flags,
encryption_config.clone(),
http_request_callback,
namespace,
)
.await?
} else {
Expand Down Expand Up @@ -434,6 +438,7 @@ cfg_sync! {
auth_token,
connector: _,
version: _,
namespace: _,
},
connector,
} = self.inner;
Expand Down Expand Up @@ -490,13 +495,21 @@ cfg_remote! {
self
}

/// Set the namespace that will be communicated to the remote in the http header.
pub fn namespace(mut self, namespace: impl Into<String>) -> Builder<Remote>
{
self.inner.namespace = Some(namespace.into());
self
}

/// Build the remote database client.
pub async fn build(self) -> Result<Database> {
let Remote {
url,
auth_token,
connector,
version,
namespace,
} = self.inner;

let connector = if let Some(connector) = connector {
Expand All @@ -518,6 +531,7 @@ cfg_remote! {
auth_token,
connector,
version,
namespace,
},
max_write_replication_index: Default::default(),
})
Expand Down
27 changes: 22 additions & 5 deletions libsql/src/hrana/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,27 @@ pub type ByteStream = Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Syn
pub struct HttpSender {
inner: hyper::Client<ConnectorService, hyper::Body>,
version: HeaderValue,
namespace: Option<HeaderValue>,
}

impl HttpSender {
pub fn new(connector: ConnectorService, version: Option<&str>) -> Self {
pub fn new(
connector: ConnectorService,
version: Option<&str>,
namespace: Option<&str>,
) -> Self {
let ver = version.unwrap_or(env!("CARGO_PKG_VERSION"));

let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap();
let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap());

let inner = hyper::Client::builder().build(connector);

Self { inner, version }
Self {
inner,
version,
namespace,
}
}

async fn send(
Expand All @@ -44,9 +54,15 @@ impl HttpSender {
auth: Arc<str>,
body: String,
) -> Result<super::HttpBody<ByteStream>> {
let req = hyper::Request::post(url.as_ref())
let mut req_builder = hyper::Request::post(url.as_ref())
.header(AUTHORIZATION, auth.as_ref())
.header("x-libsql-client-version", self.version.clone())
.header("x-libsql-client-version", self.version.clone());

if let Some(namespace) = self.namespace {
req_builder = req_builder.header("x-namespace", namespace);
}

let req = req_builder
.body(hyper::Body::from(body))
.map_err(|err| HranaError::Http(format!("{:?}", err)))?;

Expand Down Expand Up @@ -107,8 +123,9 @@ impl HttpConnection<HttpSender> {
token: impl Into<String>,
connector: ConnectorService,
version: Option<&str>,
namespace: Option<&str>,
) -> Self {
let inner = HttpSender::new(connector, version);
let inner = HttpSender::new(connector, version, namespace);
Self::new(url.into(), token.into(), inner)
}
}
Expand Down
3 changes: 2 additions & 1 deletion libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ impl Database {
flags: OpenFlags,
encryption_config: Option<EncryptionConfig>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
namespace: Option<String>,
) -> Result<Database> {
use std::path::PathBuf;

Expand All @@ -208,7 +209,7 @@ impl Database {
auth_token,
version.as_deref(),
http_request_callback,
None,
namespace,
)
.map_err(|e| crate::Error::Replication(e.into()))?;

Expand Down

0 comments on commit 91611e7

Please sign in to comment.