diff --git a/Cargo.toml b/Cargo.toml
index 0c6ae6e9..9803a336 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,8 +9,12 @@ askama = "0.12.1"
 atty = "0.2.14"
 axum = "0.7.7"
 chrono = "0.4.40"
+dashmap = "6.1"
 fern = {version = "0.7.1", features = ["colored"]}
+futures = "0.3"
 gethostname = "0.5.0"
+hyper = "0.14"
+http = "0.2"
 log = "0.4.22"
 prost = "0.13.3"
 prost-types = "0.13.3"
@@ -21,7 +25,9 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = "0.1"
 tonic = "0.12.2"
+tower = "0.4"
 
 [build-dependencies]
 tonic-build = "0.12.2"
diff --git a/src/bin/lighthouse.rs b/src/bin/lighthouse.rs
index dbce458b..70c13f7c 100644
--- a/src/bin/lighthouse.rs
+++ b/src/bin/lighthouse.rs
@@ -4,8 +4,11 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
 
+use std::net::SocketAddr;
 use structopt::StructOpt;
-use torchft::lighthouse::{Lighthouse, LighthouseOpt};
+use tonic::transport::Server;
+use torchft::lighthouse::LighthouseOpt;
+use torchft::router::Router;
 
 #[tokio::main(flavor = "multi_thread", worker_threads = 4)]
 async fn main() {
@@ -17,7 +20,12 @@ async fn main() {
         .unwrap();
 
     let opt = LighthouseOpt::from_args();
-    let lighthouse = Lighthouse::new(opt).await.unwrap();
+    let router = Router::new(opt.clone());
+    let addr: SocketAddr = opt.bind.parse().expect("invalid --bind address");
 
-    lighthouse.run().await.unwrap();
+    Server::builder()
+        .add_service(router)
+        .serve(addr)
+        .await
+        .unwrap();
 }
diff --git a/src/interceptor.rs b/src/interceptor.rs
new file mode 100644
index 00000000..fa7b7c14
--- /dev/null
+++ b/src/interceptor.rs
@@ -0,0 +1,23 @@
+use tonic::{metadata::MetadataValue, service::Interceptor, Request, Status};
+
+/// Attaches the user-assigned room-id header to every outbound RPC.
+#[derive(Clone)]
+pub struct RoomIdInterceptor {
+    room: String,
+}
+
+impl RoomIdInterceptor {
+    pub fn new(room: String) -> Self {
+        Self { room }
+    }
+}
+
+impl Interceptor for RoomIdInterceptor {
+    fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
+        req.metadata_mut().insert(
+            crate::router::ROOM_ID_HEADER,
+            MetadataValue::try_from(self.room.as_str()).expect("ascii header"),
+        );
+        Ok(req)
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 32a7a37e..7d46cb09 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,24 +4,32 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
 
+mod interceptor;
 pub mod lighthouse;
 pub mod manager;
 mod net;
 mod retry;
+pub mod router;
 mod timeout;
 
+pub use crate::router::Router;
+
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
+use gethostname::gethostname;
 use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use std::cmp;
 use std::env;
+use std::net::SocketAddr;
 use std::sync::Arc;
 use std::thread::available_parallelism;
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
-use tonic::transport::Channel;
+use tokio_stream::wrappers::TcpListenerStream;
+use tonic::service::interceptor::InterceptedService;
+use tonic::transport::{Channel, Endpoint};
 use tonic::Status;
 
 use chrono::Local;
@@ -32,6 +40,7 @@ pub mod torchftpb {
     tonic::include_proto!("torchft");
 }
 
+use crate::interceptor::RoomIdInterceptor;
 use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
 use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{
@@ -339,9 +348,13 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
 }
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
-    let lighthouse = lighthouse::Lighthouse::new(opt).await?;
+    let router = Router::new(opt.clone());
+    let addr: SocketAddr = opt.bind.parse()?;
 
-    lighthouse.run().await?;
+    tonic::transport::Server::builder()
+        .add_service(router)
+        .serve(addr)
+        .await?;
 
     Ok(())
 }
@@ -477,28 +490,39 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
 /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
 #[pyclass]
 struct LighthouseClient {
-    client: LighthouseServiceClient<Channel>,
+    client: LighthouseServiceClient<InterceptedService<Channel, RoomIdInterceptor>>,
     runtime: Runtime,
 }
 
 #[pymethods]
 impl LighthouseClient {
-    #[pyo3(signature = (addr, connect_timeout))]
+    #[pyo3(signature = (addr, connect_timeout, room_id = None))]
     #[new]
-    fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
+    fn new(
+        py: Python<'_>,
+        addr: String,
+        connect_timeout: Duration,
+        room_id: Option<String>,
+    ) -> PyResult<Self> {
         py.allow_threads(move || {
             let runtime = tokio::runtime::Builder::new_multi_thread()
                 .worker_threads(num_threads())
                 .thread_name("torchft-lhclnt")
                 .enable_all()
                 .build()?;
-            let client = runtime
-                .block_on(manager::lighthouse_client_new(addr, connect_timeout))
+
+            let endpoint = Endpoint::from_shared(addr.clone())
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
-            Ok(Self {
-                client: client,
-                runtime: runtime,
-            })
+            let channel = runtime
+                .block_on(endpoint.connect_timeout(connect_timeout).connect())
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+
+            let interceptor =
+                RoomIdInterceptor::new(room_id.unwrap_or_else(|| "default".to_owned()));
+
+            let client = LighthouseServiceClient::with_interceptor(channel, interceptor);
+
+            Ok(Self { client, runtime })
         })
     }
 
@@ -603,7 +627,7 @@ impl LighthouseClient {
 /// heartbeat_timeout_ms (int): The timeout for heartbeats.
 #[pyclass]
 struct LighthouseServer {
-    lighthouse: Arc<Lighthouse>,
+    bind: String,
     handle: JoinHandle<Result<()>>,
     _runtime: Runtime,
 }
@@ -631,19 +655,37 @@ impl LighthouseServer {
                 .enable_all()
                 .build()?;
 
-            let lighthouse = rt
-                .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
-                    bind: bind,
-                    min_replicas: min_replicas,
-                    join_timeout_ms: join_timeout_ms,
-                    quorum_tick_ms: quorum_tick_ms,
-                    heartbeat_timeout_ms: heartbeat_timeout_ms,
-                }))
-                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+            let opt = lighthouse::LighthouseOpt {
+                bind: bind.clone(),
+                min_replicas,
+                join_timeout_ms,
+                quorum_tick_ms,
+                heartbeat_timeout_ms,
+            };
+
+            let listener = rt.block_on(tokio::net::TcpListener::bind(&bind))?;
+            let bound_sock = listener.local_addr()?;
+            let incoming = TcpListenerStream::new(listener);
+            let router = Router::new(opt.clone());
+
+            let handle = rt.spawn(async move {
+                tonic::transport::Server::builder()
+                    .add_service(router)
+                    .serve_with_incoming(incoming)
+                    .await
+                    .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
+            });
+
+            let host = if bind.starts_with("0.0.0.0") || bind.starts_with("[::]") {
+                gethostname().to_string_lossy().into_owned()
+            } else {
+                bind.rsplit_once(':').map(|(h, _)| h.to_string()).unwrap()
+            };
+            let public_addr = format!("http://{}:{}", host, bound_sock.port());
 
             Ok(Self {
-                handle: rt.spawn(lighthouse.clone().run()),
-                lighthouse: lighthouse,
+                bind: public_addr,
+                handle,
                 _runtime: rt,
             })
         })
@@ -654,7 +696,7 @@ impl LighthouseServer {
     /// Returns:
     ///     str: The address of the lighthouse server.
     fn address(&self) -> PyResult<String> {
-        Ok(self.lighthouse.address().to_string())
+        Ok(self.bind.clone())
     }
 
     /// shutdown shuts down the lighthouse server.
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index a5760032..989cc13c 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -58,6 +58,7 @@ struct State {
 }
 
 pub struct Lighthouse {
+    id: String,
     state: Mutex<State>,
     opt: LighthouseOpt,
     listener: Mutex<Option<tokio::net::TcpListener>>,
@@ -83,7 +84,7 @@ impl ChangeLogger {
     }
 }
 
-#[derive(StructOpt, Debug)]
+#[derive(StructOpt, Debug, Clone)]
 #[structopt()]
 pub struct LighthouseOpt {
     // bind is the address to bind the server to.
@@ -261,12 +262,13 @@ fn quorum_compute(
 }
 
 impl Lighthouse {
-    pub async fn new(opt: LighthouseOpt) -> Result<Arc<Self>> {
+    pub async fn new(id: String, opt: LighthouseOpt) -> Result<Arc<Self>> {
         let listener = tokio::net::TcpListener::bind(&opt.bind).await?;
 
         let (tx, _) = broadcast::channel(16);
 
         Ok(Arc::new(Self {
+            id: id,
             state: Mutex::new(State {
                 participants: HashMap::new(),
                 channel: tx,
@@ -975,7 +977,7 @@ mod tests {
             quorum_tick_ms: 10,
             heartbeat_timeout_ms: 5000,
         };
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
 
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
@@ -1133,7 +1135,7 @@ mod tests {
         };
 
         // Start the lighthouse service
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
         // Create client to interact with lighthouse
@@ -1240,7 +1242,7 @@ mod tests {
         };
 
         // Start the lighthouse service
-        let lighthouse = Lighthouse::new(opt).await?;
+        let lighthouse = Lighthouse::new("".to_string(), opt).await?;
         let lighthouse_task = tokio::spawn(lighthouse.clone().run());
 
         // Create client to interact with lighthouse
diff --git a/src/manager.rs b/src/manager.rs
index e28cbeb5..79d02657 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -544,13 +544,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_should_commit() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
 
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
@@ -591,13 +594,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_get_quorum() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
 
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
@@ -646,13 +652,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_get_quorum_heal_first_step() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 2,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 2,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
 
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
@@ -718,13 +727,16 @@ mod tests {
 
     #[tokio::test]
     async fn test_checkpoint_metadata() -> Result<()> {
-        let lighthouse = Lighthouse::new(LighthouseOpt {
-            bind: "[::]:0".to_string(),
-            join_timeout_ms: 100,
-            min_replicas: 1,
-            quorum_tick_ms: 100,
-            heartbeat_timeout_ms: 5000,
-        })
+        let lighthouse = Lighthouse::new(
+            "".to_string(),
+            LighthouseOpt {
+                bind: "[::]:0".to_string(),
+                join_timeout_ms: 100,
+                min_replicas: 1,
+                quorum_tick_ms: 100,
+                heartbeat_timeout_ms: 5000,
+            },
+        )
         .await?;
 
         let lighthouse_fut = tokio::spawn(lighthouse.clone().run());
diff --git a/src/router.rs b/src/router.rs
new file mode 100644
index 00000000..22d13546
--- /dev/null
+++ b/src/router.rs
@@ -0,0 +1,106 @@
+use std::{
+    convert::Infallible,
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+use dashmap::{mapref::entry::Entry, DashMap};
+use futures::FutureExt;
+use tonic::{
+    body::BoxBody,
+    codegen::http::{HeaderMap, Request, Response},
+    server::NamedService,
+};
+use tower::Service;
+
+use crate::{
+    lighthouse::{Lighthouse, LighthouseOpt},
+    torchftpb::lighthouse_service_server::LighthouseServiceServer,
+};
+
+/// Metadata header recognised by both the client interceptor and this router.
+pub const ROOM_ID_HEADER: &str = "room-id";
+
+/// gRPC server for a single room (inner state = `Arc<Lighthouse>`).
+type GrpcSvc = LighthouseServiceServer<Arc<Lighthouse>>;
+
+#[derive(Clone)]
+pub struct Router {
+    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
+    tmpl_opt: LighthouseOpt,
+}
+
+impl Router {
+    pub fn new(tmpl_opt: LighthouseOpt) -> Self {
+        Self {
+            rooms: Arc::new(DashMap::new()),
+            tmpl_opt,
+        }
+    }
+
+    fn room_id(hdrs: &HeaderMap) -> &str {
+        hdrs.get(ROOM_ID_HEADER)
+            .and_then(|v| v.to_str().ok())
+            .unwrap_or("default")
+    }
+
+    async fn room_service(
+        rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
+        tmpl: LighthouseOpt,
+        id: &str,
+    ) -> Arc<Lighthouse> {
+        if let Some(lh) = rooms.get(id) {
+            return lh.clone();
+        }
+
+        let lh = Lighthouse::new(id.to_owned(), tmpl.clone())
+            .await
+            .expect("failed to create Lighthouse");
+
+        match rooms.entry(id.to_owned()) {
+            Entry::Occupied(e) => e.get().clone(),
+            Entry::Vacant(v) => {
+                v.insert(lh.clone());
+                lh
+            }
+        }
+    }
+}
+
+// Tower::Service implementation
+impl Service<Request<BoxBody>> for Router {
+    type Response = Response<BoxBody>;
+    type Error = Infallible;
+    type Future =
+        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
+
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
+        let rooms = self.rooms.clone();
+        let tmpl = self.tmpl_opt.clone();
+        let room = Self::room_id(req.headers()).to_owned();
+
+        async move {
+            let lh = Self::room_service(rooms, tmpl, &room).await;
+
+            let mut svc = LighthouseServiceServer::new(lh);
+            let resp = svc
+                .call(req)
+                .await
+                .map_err(|_e| -> Infallible { unreachable!() })?;
+
+            Ok(resp)
+        }
+        .boxed()
+    }
+}
+
+// Forward tonic’s NamedService marker
+impl NamedService for Router {
+    const NAME: &'static str = <GrpcSvc as NamedService>::NAME;
+}
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index 9614d1b0..31eb3481 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -89,6 +89,7 @@ class Quorum:
 class LighthouseClient:
     addr: str
     connect_timeout: timedelta
+    room_id: Optional[str] = None
 
     def quorum(
         self,
diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py
index 067a6222..c26cacd4 100644
--- a/torchft/lighthouse_test.py
+++ b/torchft/lighthouse_test.py
@@ -4,6 +4,7 @@
 
 import torch.distributed as dist
 
+import torchft.coordination as cd
 from torchft import Manager, ProcessGroupGloo
 from torchft._torchft import LighthouseClient, LighthouseServer, Quorum, QuorumMember
@@ -155,3 +156,32 @@ def test_heartbeat_round_trip(self) -> None:
 
         finally:
             lighthouse.shutdown()
+
+    def test_multi_room_quorums(self) -> None:
+        """One server, two logical rooms should yield two isolated quorums."""
+        server = cd.LighthouseServer(bind="[::]:0", min_replicas=1)
+        addr = server.address()
+
+        try:
+            # Two clients in two independent rooms
+            cli_a = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobA")
+            cli_b = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobB")
+
+            # Explicit heartbeat so each room has one participant
+            cli_a.heartbeat("a0")
+            cli_b.heartbeat("b0")
+
+            q_a = cli_a.quorum("a0", timedelta(seconds=1))
+            q_b = cli_b.quorum("b0", timedelta(seconds=1))
+
+            # Both rooms got a quorum-id of 1 but with disjoint members
+            self.assertEqual(q_a.quorum_id, 1)
+            self.assertEqual(q_b.quorum_id, 1)
+
+            self.assertEqual(len(q_a.participants), 1)
+            self.assertEqual(len(q_b.participants), 1)
+            self.assertEqual(q_a.participants[0].replica_id, "a0")
+            self.assertEqual(q_b.participants[0].replica_id, "b0")
+
+        finally:
+            server.shutdown()
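
Usage note (not part of the patch): the room-id metadata can also be attached from plain Rust without going through the Python bindings. Below is a minimal sketch, assuming a lighthouse started with the Router front-end from src/bin/lighthouse.rs, an example address of http://[::1]:29510, and that the generated proto message and field are named LighthouseHeartbeatRequest / replica_id (check torchft.proto); the inline closure plays the same role as RoomIdInterceptor, which is private to the crate.

// Sketch only: connect to a Router-backed lighthouse and tag every RPC with a room id.
// The address, room name, and proto message/field names are assumptions for illustration.
use tonic::metadata::MetadataValue;
use tonic::transport::Endpoint;
use tonic::Request;
use torchft::torchftpb::lighthouse_service_client::LighthouseServiceClient;
use torchft::torchftpb::LighthouseHeartbeatRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed address of a lighthouse serving the multi-room Router.
    let channel = Endpoint::from_static("http://[::1]:29510").connect().await?;

    // Every request sent through this client carries `room-id: jobA`, so the
    // Router dispatches it to the per-room Lighthouse instance for "jobA".
    let room = MetadataValue::try_from("jobA")?;
    let mut client =
        LighthouseServiceClient::with_interceptor(channel, move |mut req: Request<()>| {
            req.metadata_mut().insert("room-id", room.clone());
            Ok(req)
        });

    client
        .heartbeat(LighthouseHeartbeatRequest {
            replica_id: "a0".into(),
        })
        .await?;
    Ok(())
}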