diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 2a36024d5c..2a7c4c5752 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -50,6 +50,8 @@ pub fn namespace_from_headers( if let Some(from_metadata) = headers.get(NAMESPACE_METADATA_KEY) { try_namespace_from_metadata(from_metadata) + } else if let Some(from_ns_header) = headers.get("x-namespace") { + try_namespace_from_header(from_ns_header) } else if let Some(from_host) = headers.get("host") { try_namespace_from_host(from_host, disable_default_namespace) } else if !disable_default_namespace { @@ -59,6 +61,11 @@ pub fn namespace_from_headers( } } +fn try_namespace_from_header(header: &axum::http::HeaderValue) -> Result { + NamespaceName::from_bytes(header.as_bytes().to_vec().into()) + .map_err(|_| Error::InvalidNamespace) +} + fn try_namespace_from_host( from_host: &axum::http::HeaderValue, disable_default_namespace: bool, diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index e7b4b9f7f0..0c4a1c00a0 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -1696,3 +1696,43 @@ fn schema_db() { sim.run().unwrap(); } + +#[test] +fn remote_namespace_header_support() { + let tmp_host = tempdir().unwrap(); + let tmp_host_path = tmp_host.path().to_owned(); + + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + make_primary(&mut sim, tmp_host_path.clone()); + + sim.client("client", async move { + let client = Client::new(); + + client + .post("http://primary:9090/v1/namespaces/foo/create", json!({})) + .await?; + + let db_url = "http://primary:8080"; + + let remote = libsql::Builder::new_remote(db_url.to_string(), String::new()) + .namespace("foo") + .connector(TurmoilConnector) + .build() + .await + .unwrap(); + + let conn = remote.connect().unwrap(); + + conn.execute("CREATE TABLE user (id INTEGER NOT NULL PRIMARY KEY)", ()) + .await?; + + conn.execute("INSERT into user(id) values (1);", ()).await?; + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql/src/database.rs b/libsql/src/database.rs index b93ea66e98..caa772d25b 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -88,6 +88,7 @@ enum DbType { auth_token: String, connector: crate::util::ConnectorService, version: Option, + namespace: Option, }, } @@ -230,6 +231,7 @@ cfg_replication! { OpenFlags::default(), encryption_config.clone(), None, + None, ).await?; Ok(Database { @@ -514,6 +516,7 @@ cfg_remote! { auth_token: auth_token.into(), connector: crate::util::ConnectorService::new(svc), version, + namespace: None, }, max_write_replication_index: Default::default(), }) @@ -650,6 +653,7 @@ impl Database { auth_token, connector, version, + namespace, } => { let conn = std::sync::Arc::new( crate::hrana::connection::HttpConnection::new_with_connector( @@ -657,6 +661,7 @@ impl Database { auth_token, connector.clone(), version.as_ref().map(|s| s.as_str()), + namespace.as_ref().map(|s| s.as_str()), ), ); diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index f1c60ffb3c..e835a28075 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -59,12 +59,12 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, encryption_config: None, read_your_writes: true, sync_interval: None, http_request_callback: None, - namespace: None }, } } @@ -99,6 +99,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, connector:None, }, @@ -115,6 +116,7 @@ impl Builder<()> { auth_token, connector: None, version: None, + namespace: None, }, } } @@ -128,6 +130,7 @@ cfg_replication_or_remote_or_sync! { auth_token: String, connector: Option, version: Option, + namespace: Option, } } @@ -195,7 +198,6 @@ cfg_replication! { read_your_writes: bool, sync_interval: Option, http_request_callback: Option, - namespace: Option, } /// Local replica configuration type in [`Builder`]. @@ -260,7 +262,7 @@ cfg_replication! { /// Set the namespace that will be communicated to remote replica in the http header. pub fn namespace(mut self, namespace: impl Into) -> Builder { - self.inner.namespace = Some(namespace.into()); + self.inner.remote.namespace = Some(namespace.into()); self } @@ -280,12 +282,12 @@ cfg_replication! { auth_token, connector, version, + namespace, }, encryption_config, read_your_writes, sync_interval, http_request_callback, - namespace } = self.inner; let connector = if let Some(connector) = connector { @@ -357,6 +359,7 @@ cfg_replication! { auth_token, connector, version, + namespace, }) = remote { let connector = if let Some(connector) = connector { @@ -381,6 +384,7 @@ cfg_replication! { flags, encryption_config.clone(), http_request_callback, + namespace, ) .await? } else { @@ -434,6 +438,7 @@ cfg_sync! { auth_token, connector: _, version: _, + namespace: _, }, connector, } = self.inner; @@ -490,6 +495,13 @@ cfg_remote! { self } + /// Set the namespace that will be communicated to the remote in the http header. + pub fn namespace(mut self, namespace: impl Into) -> Builder + { + self.inner.namespace = Some(namespace.into()); + self + } + /// Build the remote database client. pub async fn build(self) -> Result { let Remote { @@ -497,6 +509,7 @@ cfg_remote! { auth_token, connector, version, + namespace, } = self.inner; let connector = if let Some(connector) = connector { @@ -518,6 +531,7 @@ cfg_remote! { auth_token, connector, version, + namespace, }, max_write_replication_index: Default::default(), }) diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs index 7ebab81a6d..c90477ea75 100644 --- a/libsql/src/hrana/hyper.rs +++ b/libsql/src/hrana/hyper.rs @@ -25,17 +25,27 @@ pub type ByteStream = Box> + Send + Syn pub struct HttpSender { inner: hyper::Client, version: HeaderValue, + namespace: Option, } impl HttpSender { - pub fn new(connector: ConnectorService, version: Option<&str>) -> Self { + pub fn new( + connector: ConnectorService, + version: Option<&str>, + namespace: Option<&str>, + ) -> Self { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); let version = HeaderValue::try_from(format!("libsql-remote-{ver}")).unwrap(); + let namespace = namespace.map(|v| HeaderValue::try_from(v).unwrap()); let inner = hyper::Client::builder().build(connector); - Self { inner, version } + Self { + inner, + version, + namespace, + } } async fn send( @@ -44,9 +54,15 @@ impl HttpSender { auth: Arc, body: String, ) -> Result> { - let req = hyper::Request::post(url.as_ref()) + let mut req_builder = hyper::Request::post(url.as_ref()) .header(AUTHORIZATION, auth.as_ref()) - .header("x-libsql-client-version", self.version.clone()) + .header("x-libsql-client-version", self.version.clone()); + + if let Some(namespace) = self.namespace { + req_builder = req_builder.header("x-namespace", namespace); + } + + let req = req_builder .body(hyper::Body::from(body)) .map_err(|err| HranaError::Http(format!("{:?}", err)))?; @@ -107,8 +123,9 @@ impl HttpConnection { token: impl Into, connector: ConnectorService, version: Option<&str>, + namespace: Option<&str>, ) -> Self { - let inner = HttpSender::new(connector, version); + let inner = HttpSender::new(connector, version, namespace); Self::new(url.into(), token.into(), inner) } } diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 64f09fcc12..79563efc4c 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -190,6 +190,7 @@ impl Database { flags: OpenFlags, encryption_config: Option, http_request_callback: Option, + namespace: Option, ) -> Result { use std::path::PathBuf; @@ -208,7 +209,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, - None, + namespace, ) .map_err(|e| crate::Error::Replication(e.into()))?;