From aefc06d10f239663296a02be40b347002d47172b Mon Sep 17 00:00:00 2001 From: MasterPtato Date: Sat, 21 Sep 2024 01:05:39 +0000 Subject: [PATCH] chore(pegboard): get connection working e2e --- lib/pegboard/manager/Cargo.toml | 1 - lib/pegboard/manager/src/main.rs | 29 +-- lib/pegboard/manager/src/utils.rs | 8 +- svc/pkg/cluster/src/util/metrics.rs | 9 +- .../files/pegboard_configure.sh | 2 +- svc/pkg/cluster/src/workflows/server/mod.rs | 190 +++++++++++----- svc/pkg/pegboard/src/lib.rs | 1 - svc/pkg/pegboard/src/utils.rs | 1 - svc/pkg/pegboard/src/workflows/client.rs | 38 +++- svc/pkg/pegboard/standalone/ws/src/lib.rs | 215 ++++++++++++------ 10 files changed, 320 insertions(+), 174 deletions(-) delete mode 100644 svc/pkg/pegboard/src/utils.rs diff --git a/lib/pegboard/manager/Cargo.toml b/lib/pegboard/manager/Cargo.toml index adcaa4ea11..61bedc207e 100644 --- a/lib/pegboard/manager/Cargo.toml +++ b/lib/pegboard/manager/Cargo.toml @@ -13,7 +13,6 @@ indoc = "2.0" lazy_static = "1.4" nix = { version = "0.27", default-features = false, features = ["user", "signal"] } notify = { version = "6.1.1", default-features = false, features = [ "serde" ] } -pnet_datalink = "0.35.0" prometheus = "0.13" rand = "0.8" reqwest = { version = "0.11", features = ["stream"] } diff --git a/lib/pegboard/manager/src/main.rs b/lib/pegboard/manager/src/main.rs index 669c0089bf..04dc525aeb 100644 --- a/lib/pegboard/manager/src/main.rs +++ b/lib/pegboard/manager/src/main.rs @@ -1,7 +1,4 @@ -use std::{ - net::{IpAddr, Ipv4Addr}, - path::Path, -}; +use std::{net::Ipv4Addr, path::Path}; use anyhow::*; use futures_util::StreamExt; @@ -33,7 +30,7 @@ async fn main() -> Result<()> { tokio::spawn(metrics::run_standalone()); let client_id = Uuid::parse_str(&utils::var("CLIENT_ID")?)?; - let network_ip = get_network_ip(&utils::var("NETWORK_INTERFACE")?)?; + let network_ip = utils::var("NETWORK_IP")?.parse::()?; let system = System::new_with_specifics( RefreshKind::new() @@ -76,28 +73,6 @@ async fn main() -> Result<()> { ctx.start(rx).await } -fn get_network_ip(network_interface_name: &str) -> Result { - let network_interface = pnet_datalink::interfaces() - .into_iter() - .find(|iface| iface.name == network_interface_name) - .context(format!( - "network interface not found: {network_interface_name}" - ))?; - let network_ip = network_interface - .ips - .iter() - .find_map(|net| { - if let IpAddr::V4(ip) = net.ip() { - Some(ip) - } else { - None - } - }) - .context("no ipv4 network on interface")?; - - Ok(network_ip) -} - fn init_tracing() { tracing_subscriber::registry() .with( diff --git a/lib/pegboard/manager/src/utils.rs b/lib/pegboard/manager/src/utils.rs index 6d298d64cb..08205cccd6 100644 --- a/lib/pegboard/manager/src/utils.rs +++ b/lib/pegboard/manager/src/utils.rs @@ -120,7 +120,10 @@ pub async fn init_sqlite_schema(pool: &SqlitePool) -> Result<()> { " CREATE TABLE IF NOT EXISTS state ( last_event_idx INTEGER NOT NULL, - last_command_idx INTEGER NOT NULL + last_command_idx INTEGER NOT NULL, + + -- Keeps this table having one row + _persistence BOOLEAN UNIQUE NOT NULL DEFAULT TRUE ) ", )) @@ -129,8 +132,9 @@ pub async fn init_sqlite_schema(pool: &SqlitePool) -> Result<()> { sqlx::query(indoc!( " - INSERT INTO state + INSERT INTO state (last_event_idx, last_command_idx) VALUES (0, 0) + ON CONFLICT DO NOTHING ", )) .execute(&mut *conn) diff --git a/svc/pkg/cluster/src/util/metrics.rs b/svc/pkg/cluster/src/util/metrics.rs index 35cce529ae..5826967d1b 100644 --- a/svc/pkg/cluster/src/util/metrics.rs +++ b/svc/pkg/cluster/src/util/metrics.rs @@ -66,7 +66,14 @@ lazy_static::lazy_static! { ).unwrap(); pub static ref NOMAD_JOIN_DURATION: HistogramVec = register_histogram_vec_with_registry!( "provision_nomad_join_duration", - "Time from installed to nomad joined.", + "Time from installed to Nomad joined.", + &["cluster_id", "datacenter_id", "provider_datacenter_id", "datacenter_name_id"], + PROVISION_BUCKETS.to_vec(), + *REGISTRY, + ).unwrap(); + pub static ref PEGBOARD_JOIN_DURATION: HistogramVec = register_histogram_vec_with_registry!( + "provision_pegboard_join_duration", + "Time from installed to Pegboard joined.", &["cluster_id", "datacenter_id", "provider_datacenter_id", "datacenter_name_id"], PROVISION_BUCKETS.to_vec(), *REGISTRY, diff --git a/svc/pkg/cluster/src/workflows/server/install/install_scripts/files/pegboard_configure.sh b/svc/pkg/cluster/src/workflows/server/install/install_scripts/files/pegboard_configure.sh index 2dcaf5c197..daa51129dd 100644 --- a/svc/pkg/cluster/src/workflows/server/install/install_scripts/files/pegboard_configure.sh +++ b/svc/pkg/cluster/src/workflows/server/install/install_scripts/files/pegboard_configure.sh @@ -295,7 +295,7 @@ ConditionPathExists=/etc/pegboard/ [Service] Environment="CLIENT_ID=___SERVER_ID___" -Environment="NETWORK_INTERFACE=__VLAN_IFACE__" +Environment="NETWORK_IP=___VLAN_IP___" ExecStart=/usr/bin/pegboard Restart=always RestartSec=2 diff --git a/svc/pkg/cluster/src/workflows/server/mod.rs b/svc/pkg/cluster/src/workflows/server/mod.rs index 49e8ee850d..1148ad512f 100644 --- a/svc/pkg/cluster/src/workflows/server/mod.rs +++ b/svc/pkg/cluster/src/workflows/server/mod.rs @@ -5,6 +5,7 @@ use std::{ use chirp_workflow::prelude::*; use rand::Rng; +use serde_json::json; pub(crate) mod dns_create; pub(crate) mod dns_delete; @@ -185,13 +186,24 @@ pub(crate) async fn cluster_server(ctx: &mut WorkflowCtx, input: &Input) -> Glob .send() .await?; - // Create DNS record because the server is already installed - if let PoolType::Gg = input.pool_type { - ctx.workflow(dns_create::Input { - server_id: input.server_id, - }) - .output() - .await?; + match input.pool_type { + // Create DNS record because the server is already installed + PoolType::Gg => { + ctx.workflow(dns_create::Input { + server_id: input.server_id, + }) + .output() + .await?; + } + // Update tags to include pegboard client_id (currently the same as the server_id) + PoolType::Pegboard => { + ctx.activity(UpdateTagsInput { + server_id: input.server_id, + client_id: input.server_id, + }) + .await?; + } + _ => {} } provider_server_workflow_id @@ -252,6 +264,23 @@ pub(crate) async fn cluster_server(ctx: &mut WorkflowCtx, input: &Input) -> Glob .send() .await?; } + Main::PegboardRegistered(_) => { + ctx.activity(SetPegboardClientIdInput { + server_id: input.server_id, + cluster_id: dc.cluster_id, + datacenter_id: dc.datacenter_id, + provider_datacenter_id: dc.provider_datacenter_id.clone(), + datacenter_name_id: dc.name_id.clone(), + client_id: input.server_id, + }) + .await?; + + // Scale to get rid of tainted servers + ctx.signal(crate::workflows::datacenter::Scale {}) + .tag("datacenter_id", input.datacenter_id) + .send() + .await?; + } Main::Drain(_) => { ctx.workflow(drain::Input { datacenter_id: input.datacenter_id, @@ -495,40 +524,37 @@ async fn update_db(ctx: &ActivityCtx, input: &UpdateDbInput) -> GlobalResult<()> ) .await?; - insert_metrics( - input.cluster_id, - input.datacenter_id, - &input.provider_datacenter_id, - &input.datacenter_name_id, - &input.pool_type, - provision_complete_ts, - create_ts, - ) - .await; - - Ok(()) -} - -async fn insert_metrics( - cluster_id: Uuid, - datacenter_id: Uuid, - provider_datacenter_id: &str, - datacenter_name_id: &str, - pool_type: &PoolType, - provision_complete_ts: i64, - create_ts: i64, -) { + // Insert metrics let dt = (provision_complete_ts - create_ts) as f64 / 1000.0; metrics::PROVISION_DURATION .with_label_values(&[ - &cluster_id.to_string(), - &datacenter_id.to_string(), - provider_datacenter_id, - datacenter_name_id, - &pool_type.to_string(), + &input.cluster_id.to_string(), + &input.datacenter_id.to_string(), + &input.provider_datacenter_id, + &input.datacenter_name_id, + &input.pool_type.to_string(), ]) .observe(dt); + + Ok(()) +} + +#[derive(Debug, Serialize, Deserialize, Hash)] +struct UpdateTagsInput { + server_id: Uuid, + client_id: Uuid, +} + +#[activity(UpdateTags)] +async fn update_tags(ctx: &ActivityCtx, input: &UpdateTagsInput) -> GlobalResult<()> { + ctx.update_workflow_tags(&json!({ + "server_id": input.server_id, + "client_id": input.client_id, + })) + .await?; + + Ok(()) } #[derive(Debug, Serialize, Deserialize, Hash)] @@ -575,8 +601,7 @@ async fn set_nomad_node_id(ctx: &ActivityCtx, input: &SetNomadNodeIdInput) -> Gl SET nomad_node_id = $2, nomad_join_ts = $3 - WHERE - server_id = $1 + WHERE server_id = $1 RETURNING nomad_node_id, install_complete_ts ", input.server_id, @@ -591,14 +616,16 @@ async fn set_nomad_node_id(ctx: &ActivityCtx, input: &SetNomadNodeIdInput) -> Gl // Insert metrics if let Some(install_complete_ts) = install_complete_ts { - insert_nomad_metrics( - input.cluster_id, - input.datacenter_id, - &input.provider_datacenter_id, - &input.datacenter_name_id, - nomad_join_ts, - install_complete_ts, - ); + let dt = (nomad_join_ts - install_complete_ts) as f64 / 1000.0; + + metrics::NOMAD_JOIN_DURATION + .with_label_values(&[ + &input.cluster_id.to_string(), + &input.datacenter_id.to_string(), + &input.provider_datacenter_id, + &input.datacenter_name_id, + ]) + .observe(dt); } else { tracing::warn!("missing install_complete_ts"); } @@ -606,24 +633,58 @@ async fn set_nomad_node_id(ctx: &ActivityCtx, input: &SetNomadNodeIdInput) -> Gl Ok(()) } -fn insert_nomad_metrics( +#[derive(Debug, Serialize, Deserialize, Hash)] +struct SetPegboardClientIdInput { + server_id: Uuid, cluster_id: Uuid, datacenter_id: Uuid, - provider_datacenter_id: &str, - datacenter_name_id: &str, - nomad_join_ts: i64, - install_complete_ts: i64, -) { - let dt = (nomad_join_ts - install_complete_ts) as f64 / 1000.0; - - metrics::NOMAD_JOIN_DURATION - .with_label_values(&[ - &cluster_id.to_string(), - &datacenter_id.to_string(), - provider_datacenter_id, - datacenter_name_id, - ]) - .observe(dt); + provider_datacenter_id: String, + datacenter_name_id: String, + client_id: Uuid, +} + +#[activity(SetPegboardClientId)] +async fn set_pegboard_client_id( + ctx: &ActivityCtx, + input: &SetPegboardClientIdInput, +) -> GlobalResult<()> { + let pegboard_join_ts = util::timestamp::now(); + + let (old_pegboard_client_id, install_complete_ts) = sql_fetch_one!( + [ctx, (Option, Option)] + " + UPDATE db_cluster.servers + SET + pegboard_client_id = $2 + WHERE server_id = $1 + RETURNING pegboard_client_id, install_complete_ts + ", + input.server_id, + &input.client_id, + ) + .await?; + + if let Some(old_pegboard_client_id) = old_pegboard_client_id { + tracing::warn!(%old_pegboard_client_id, "pegboard client id was already set"); + } + + // Insert metrics + if let Some(install_complete_ts) = install_complete_ts { + let dt = (pegboard_join_ts - install_complete_ts) as f64 / 1000.0; + + metrics::PEGBOARD_JOIN_DURATION + .with_label_values(&[ + &input.cluster_id.to_string(), + &input.datacenter_id.to_string(), + &input.provider_datacenter_id, + &input.datacenter_name_id, + ]) + .observe(dt); + } else { + tracing::warn!("missing install_complete_ts"); + } + + Ok(()) } #[derive(Debug, Serialize, Deserialize, Hash)] @@ -745,7 +806,11 @@ impl CustomListener for State { */ async fn listen(&self, ctx: &ListenCtx) -> WorkflowResult { // Determine which signals to listen to - let mut signals = vec![Destroy::NAME, NomadRegistered::NAME]; + let mut signals = vec![ + Destroy::NAME, + NomadRegistered::NAME, + pegboard::workflows::client::Registered::NAME, + ]; if !self.draining { signals.push(Drain::NAME); @@ -823,4 +888,5 @@ join_signal!(Main { DnsDelete, Destroy, NomadRegistered, + PegboardRegistered(pegboard::workflows::client::Registered), }); diff --git a/svc/pkg/pegboard/src/lib.rs b/svc/pkg/pegboard/src/lib.rs index 771bbd8d0a..2390673b5d 100644 --- a/svc/pkg/pegboard/src/lib.rs +++ b/svc/pkg/pegboard/src/lib.rs @@ -3,7 +3,6 @@ use chirp_workflow::prelude::*; pub mod ops; pub mod protocol; pub mod types; -pub mod utils; pub mod workflows; pub fn registry() -> WorkflowResult { diff --git a/svc/pkg/pegboard/src/utils.rs b/svc/pkg/pegboard/src/utils.rs deleted file mode 100644 index 8b13789179..0000000000 --- a/svc/pkg/pegboard/src/utils.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/svc/pkg/pegboard/src/workflows/client.rs b/svc/pkg/pegboard/src/workflows/client.rs index 0cf741d3c5..0a318c3898 100644 --- a/svc/pkg/pegboard/src/workflows/client.rs +++ b/svc/pkg/pegboard/src/workflows/client.rs @@ -14,6 +14,9 @@ pub struct Input { #[workflow] pub async fn pegboard_client(ctx: &mut WorkflowCtx, input: &Input) -> GlobalResult<()> { + // Whatever started this client should be listening for this + ctx.signal(Registered { }).tag("client_id", input.client_id).send().await?; + ctx.repeat(|ctx| { let client_id = input.client_id; @@ -136,6 +139,14 @@ pub async fn pegboard_client(ctx: &mut WorkflowCtx, input: &Input) -> GlobalResu ) .await?; + // Close websocket connection + ctx.msg(CloseWs { + client_id: input.client_id, + }) + .tags(json!({})) + .send() + .await?; + Ok(()) } @@ -163,10 +174,10 @@ async fn process_init( " UPDATE db_pegboard.clients SET - cpu = $2 AND + cpu = $2, memory = $3 WHERE client_id = $1 - RETURNING last_event_idx + RETURNING last_event_idx ", input.client_id, input.system.cpu as i64, @@ -330,6 +341,10 @@ pub async fn handle_commands( client_id: Uuid, commands: Vec, ) -> GlobalResult<()> { + if commands.is_empty() { + return Ok(()); + } + let raw_commands = commands .iter() .map(protocol::Raw::new) @@ -402,7 +417,7 @@ async fn insert_commands(ctx: &ActivityCtx, input: &InsertCommandsInput) -> Glob [ctx, (i64,)] " WITH - last_idx AS ( + last_event_idx(idx) AS ( UPDATE db_pegboard.clients SET last_command_idx = last_command_idx + 1 WHERE client_id = $1 @@ -415,12 +430,12 @@ async fn insert_commands(ctx: &ActivityCtx, input: &InsertCommandsInput) -> Glob index, create_ts ) - SELECT $1, p.payload, last_idx.last_command_idx + p.index - 1, $3 - FROM last_idx + SELECT $1, p.payload, l.idx + p.index - 1, $3 + FROM last_event_idx AS l CROSS JOIN UNNEST($2) WITH ORDINALITY AS p(payload, index) RETURNING 1 ) - SELECT last_event_idx FROM last_idx + SELECT idx FROM last_event_idx ", input.client_id, &input.commands, @@ -468,7 +483,7 @@ async fn fetch_all_containers( [ctx, (Uuid,)] " SELECT container_id - FROM db_pegboard.clients + FROM db_pegboard.containers WHERE client_id = $1 AND stopping_ts IS NULL AND @@ -485,12 +500,21 @@ async fn fetch_all_containers( Ok(container_ids) } +#[signal("pegboard_client_registered")] +pub struct Registered { +} + #[message("pegboard_client_to_ws")] pub struct ToWs { pub client_id: Uuid, pub inner: protocol::ToClient, } +#[message("pegboard_client_close_ws")] +pub struct CloseWs { + pub client_id: Uuid, +} + #[signal("pegboard_container_state_update")] pub struct ContainerStateUpdate { pub state: protocol::ContainerState, diff --git a/svc/pkg/pegboard/standalone/ws/src/lib.rs b/svc/pkg/pegboard/standalone/ws/src/lib.rs index 3690e3c530..ac1a100d10 100644 --- a/svc/pkg/pegboard/standalone/ws/src/lib.rs +++ b/svc/pkg/pegboard/standalone/ws/src/lib.rs @@ -1,4 +1,12 @@ -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; use chirp_workflow::prelude::*; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; @@ -11,12 +19,15 @@ use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; use pegboard::protocol; +const UPDATE_PING_INTERVAL: Duration = Duration::from_secs(3); + struct Connection { protocol_version: u16, - tx: Arc, Message>>>, + tx: Mutex, Message>>, + update_ping: AtomicBool, } -type Connections = HashMap; +type Connections = HashMap>; #[tracing::instrument(skip_all)] pub async fn run_from_env(pools: rivet_pools::Pools) -> GlobalResult<()> { @@ -33,7 +44,8 @@ pub async fn run_from_env(pools: rivet_pools::Pools) -> GlobalResult<()> { tokio::try_join!( socket_thread(&ctx, conns.clone()), - signal_thread(&ctx, conns.clone()), + msg_thread(&ctx, conns.clone()), + update_ping_thread(&ctx, conns.clone()), )?; Ok(()) @@ -70,21 +82,23 @@ async fn handle_connection( Ok(x) => x, Err(err) => { tracing::error!(?addr, "{err}"); - return Err(err); + return; } }; - // Handle result for cleanup - match handle_connection_inner(&ctx, conns.clone(), ws_stream, protocol_version, client_id) - .await + if let Err(err) = + handle_connection_inner(&ctx, conns.clone(), ws_stream, protocol_version, client_id) + .await { - Err(err) => { - // Clean up connection - conns.write().await.remove(&client_id); + tracing::error!(?addr, "{err}"); + } - Err(err) + // Clean up + let conn = conns.write().await.remove(&client_id); + if let Some(conn) = conn { + if let Err(err) = conn.tx.lock().await.send(Message::Close(None)).await { + tracing::error!(?addr, "failed closing socket: {err}"); } - x => x, } }); } @@ -123,16 +137,16 @@ async fn handle_connection_inner( ) -> GlobalResult<()> { let (tx, mut rx) = ws_stream.split(); - let tx = Arc::new(Mutex::new(tx)); - let conn = Connection { + let conn = Arc::new(Connection { protocol_version, - tx: tx.clone(), - }; + tx: Mutex::new(tx), + update_ping: AtomicBool::new(false), + }); // Store connection { let mut conns = conns.write().await; - if let Some(old_conn) = conns.insert(client_id, conn) { + if let Some(old_conn) = conns.insert(client_id, conn.clone()) { tracing::warn!( ?client_id, "client already connected, closing old connection" @@ -158,8 +172,8 @@ async fn handle_connection_inner( .await?; } Message::Ping(_) => { - update_ping(ctx, client_id).await?; - tx.lock().await.send(Message::Pong(Vec::new())).await?; + conn.update_ping.store(true, Ordering::Relaxed); + conn.tx.lock().await.send(Message::Pong(Vec::new())).await?; } Message::Close(_) => { bail!(format!("socket closed {client_id}")); @@ -179,24 +193,47 @@ async fn handle_connection_inner( async fn upsert_client(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> { // Inserting before creating the workflow prevents a race condition with using select + insert instead - let inserted = sql_fetch_optional!( - [ctx, (i64,)] + let (exists, deleted) = sql_fetch_one!( + [ctx, (bool, bool)] " - INSERT INTO db_pegboard.clients (client_id, create_ts, last_ping_ts) - VALUES ($1, $2, $2) - ON CONFLICT (client_id) DO NOTHING - RETURNING 1 + WITH + select_exists AS ( + SELECT 1 + FROM db_pegboard.clients + WHERE client_id = $1 + ), + select_deleted AS ( + SELECT 1 + FROM db_pegboard.clients + WHERE + client_id = $1 AND + delete_ts IS NOT NULL + ), + insert_client AS ( + INSERT INTO db_pegboard.clients (client_id, create_ts, last_ping_ts) + VALUES ($1, $2, $2) + ON CONFLICT (client_id) + DO UPDATE + SET delete_ts = NULL + RETURNING 1 + ) + SELECT + EXISTS(SELECT 1 FROM select_exists) AS exists, + EXISTS(SELECT 1 FROM select_deleted) AS deleted ", client_id, util::timestamp::now(), ) - .await? - .is_some(); + .await?; - // If the row was inserted, spawn a new client workflow - if inserted { - tracing::info!(?client_id, "new client"); + if deleted { + tracing::warn!(?client_id, "client was previously deleted"); + } + if exists == deleted { + tracing::info!(?client_id, "new client"); + + // Spawn a new client workflow ctx.workflow(pegboard::workflows::client::Input { client_id }) .tag("client_id", client_id) .dispatch() @@ -206,56 +243,92 @@ async fn upsert_client(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> Ok(()) } -async fn update_ping(ctx: &StandaloneCtx, client_id: Uuid) -> GlobalResult<()> { - let (deleted,) = sql_fetch_one!( - [ctx, (bool,)] - " - WITH - select_old AS ( - SELECT delete_ts IS NOT NULL AS deleted - FROM db_pegboard.clients - WHERE client_id = $1 - ) - UPDATE db_pegboard.clients - SET - last_ping_ts = $2 AND - delete_ts = NULL - WHERE client_id = $1 - RETURNING (SELECT * FROM select_old) - ", - client_id, - util::timestamp::now(), - ) - .await?; +/// Updates the ping of all clients requesting a ping update at once. +async fn update_ping_thread( + ctx: &StandaloneCtx, + conns: Arc>, +) -> GlobalResult<()> { + loop { + tokio::time::sleep(UPDATE_PING_INTERVAL).await; - if deleted { - tracing::warn!(?client_id, "deleted client reconnected"); - } + let client_ids = { + let conns = conns.read().await; - Ok(()) + // Select all clients that required a ping update + conns + .iter() + .filter_map(|(client_id, conn)| { + conn.update_ping + .swap(false, Ordering::Relaxed) + .then_some(*client_id) + }) + .collect::>() + }; + + if client_ids.is_empty() { + continue; + } + + sql_execute!( + [ctx] + " + UPDATE db_pegboard.clients + SET last_ping_ts = $2 + WHERE client_id = ANY($1) + RETURNING 1 + ", + client_ids, + util::timestamp::now(), + ) + .await?; + } } -async fn signal_thread(ctx: &StandaloneCtx, conns: Arc>) -> GlobalResult<()> { +async fn msg_thread(ctx: &StandaloneCtx, conns: Arc>) -> GlobalResult<()> { // Listen for commands from client workflows let mut sub = ctx .subscribe::(&json!({})) .await?; + let mut close_sub = ctx + .subscribe::(&json!({})) + .await?; loop { - let msg = sub.next().await?; - - { - let conns = conns.read().await; - - // Send command to socket - if let Some(conn) = conns.get(&msg.client_id) { - let buf = msg.inner.serialize(conn.protocol_version)?; - conn.tx.lock().await.send(Message::Binary(buf)).await?; - } else { - tracing::debug!( - client_id=?msg.client_id, - "received command for client that isn't connected, ignoring" - ); + tokio::select! { + msg = sub.next() => { + let msg = msg?; + + { + let conns = conns.read().await; + + // Send command to socket + if let Some(conn) = conns.get(&msg.client_id) { + let buf = msg.inner.serialize(conn.protocol_version)?; + conn.tx.lock().await.send(Message::Binary(buf)).await?; + } else { + tracing::debug!( + client_id=?msg.client_id, + "received command for client that isn't connected, ignoring" + ); + } + } + } + msg = close_sub.next() => { + let msg = msg?; + + { + let conns = conns.read().await; + + // Close socket + if let Some(conn) = conns.get(&msg.client_id) { + conn.tx.lock().await.send(Message::Close(None)).await?; + } else { + tracing::debug!( + client_id=?msg.client_id, + "received close command for client that isn't connected, ignoring" + ); + } + } } } }