From 48a2fd0795d788768c0e1ff215b743a531e982bb Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Fri, 7 Nov 2025 17:43:13 -0800 Subject: [PATCH 1/2] Don't retry worker shutdown calls --- crates/client/src/lib.rs | 4 + crates/client/src/retry.rs | 9 +- crates/sdk-core/src/worker/client.rs | 12 +- crates/sdk-core/src/worker/mod.rs | 12 +- .../sdk-core/tests/common/fake_grpc_server.rs | 103 ++++++++++++++++ crates/sdk-core/tests/common/mod.rs | 11 +- .../tests/integ_tests/client_tests.rs | 113 ++---------------- .../tests/integ_tests/worker_tests.rs | 37 +++++- 8 files changed, 181 insertions(+), 120 deletions(-) create mode 100644 crates/sdk-core/tests/common/fake_grpc_server.rs diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index 5f3e9d6ec..da3e37b95 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -329,6 +329,10 @@ pub struct NoRetryOnMatching { pub predicate: fn(&tonic::Status) -> bool, } +/// A request extension that forces overriding the current retry policy of the [RetryClient]. +#[derive(Clone, Debug)] +pub struct RetryConfigForCall(pub RetryConfig); + impl Debug for ClientTlsConfig { // Intentionally omit details here since they could leak a key if ever printed fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { diff --git a/crates/client/src/retry.rs b/crates/client/src/retry.rs index 053f80fc2..f5794ef36 100644 --- a/crates/client/src/retry.rs +++ b/crates/client/src/retry.rs @@ -1,6 +1,7 @@ use crate::{ ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY, - NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll, + NamespacedClient, NoRetryOnMatching, Result, RetryConfig, RetryConfigForCall, + raw::IsUserLongPoll, }; use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff}; use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy}; @@ -60,6 +61,7 @@ impl RetryClient { ) -> CallInfo { let mut call_type = CallType::Normal; let mut retry_short_circuit = None; + let mut retry_cfg_override = None; if let Some(r) = request.as_ref() { let ext = r.extensions(); if ext.get::().is_some() { @@ -69,8 +71,11 @@ impl RetryClient { } retry_short_circuit = ext.get::().cloned(); + retry_cfg_override = ext.get::().cloned(); } - let retry_cfg = if call_type == CallType::TaskLongPoll { + let retry_cfg = if let Some(ovr) = retry_cfg_override { + ovr.0 + } else if call_type == CallType::TaskLongPoll { RetryConfig::task_poll_retry_policy() } else { (*self.retry_config).clone() diff --git a/crates/sdk-core/src/worker/client.rs b/crates/sdk-core/src/worker/client.rs index effa9dfa1..96f0514f5 100644 --- a/crates/sdk-core/src/worker/client.rs +++ b/crates/sdk-core/src/worker/client.rs @@ -11,7 +11,7 @@ use std::{ }; use temporalio_client::{ Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, - RetryClient, SharedReplaceableClient, WorkflowService, + RetryClient, RetryConfig, RetryConfigForCall, SharedReplaceableClient, WorkflowService, }; use temporalio_common::{ protos::{ @@ -697,16 +697,20 @@ impl WorkerClient for WorkerClientBag { w.status = WorkerStatus::Shutdown.into(); self.set_heartbeat_client_fields(w); } - let request = ShutdownWorkerRequest { + let mut request = ShutdownWorkerRequest { namespace: self.namespace.clone(), identity: self.identity.clone(), sticky_task_queue, reason: "graceful shutdown".to_string(), worker_heartbeat: final_heartbeat, - }; + } + .into_request(); + request + .extensions_mut() + .insert(RetryConfigForCall(RetryConfig::no_retries())); Ok( - WorkflowService::shutdown_worker(&mut self.client.clone(), request.into_request()) + WorkflowService::shutdown_worker(&mut self.client.clone(), request) .await? .into_inner(), ) diff --git a/crates/sdk-core/src/worker/mod.rs b/crates/sdk-core/src/worker/mod.rs index 4f9c67ce7..5721079b1 100644 --- a/crates/sdk-core/src/worker/mod.rs +++ b/crates/sdk-core/src/worker/mod.rs @@ -361,7 +361,12 @@ impl Worker { #[cfg(test)] pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self { - Self::new(config, None, Arc::new(client), None, None).unwrap() + let sticky_queue_name = if config.max_cached_workflows > 0 { + Some(format!("sticky-{}", config.task_queue)) + } else { + None + }; + Self::new(config, sticky_queue_name, Arc::new(client), None, None).unwrap() } pub(crate) fn new_with_pollers( @@ -698,7 +703,10 @@ impl Worker { tonic::Code::Unimplemented | tonic::Code::Unavailable ) => { - warn!("Failed to shutdown sticky queue {:?}", err); + warn!( + "shutdown_worker rpc errored during worker shutdown: {:?}", + err + ); } _ => {} } diff --git a/crates/sdk-core/tests/common/fake_grpc_server.rs b/crates/sdk-core/tests/common/fake_grpc_server.rs new file mode 100644 index 000000000..8e7e0ee06 --- /dev/null +++ b/crates/sdk-core/tests/common/fake_grpc_server.rs @@ -0,0 +1,103 @@ +use futures_util::future::{BoxFuture, FutureExt}; +use std::{ + convert::Infallible, + task::{Context, Poll}, +}; +use tokio::{ + net::TcpListener, + sync::{mpsc::UnboundedSender, oneshot}, +}; +use tonic::{ + body::Body, codegen::Service, codegen::http::Response, server::NamedService, transport::Server, +}; + +#[derive(Clone)] +pub(crate) struct GenericService { + pub header_to_parse: &'static str, + pub header_tx: UnboundedSender, + pub response_maker: F, +} +impl Service> for GenericService +where + F: FnMut(tonic::codegen::http::Request) -> BoxFuture<'static, Response>, +{ + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: tonic::codegen::http::Request) -> Self::Future { + self.header_tx + .send( + String::from_utf8_lossy( + req.headers() + .get(self.header_to_parse) + .map(|hv| hv.as_bytes()) + .unwrap_or_default(), + ) + .to_string(), + ) + .unwrap(); + let r = (self.response_maker)(req); + async move { Ok(r.await) }.boxed() + } +} +impl NamedService for GenericService { + const NAME: &'static str = "temporal.api.workflowservice.v1.WorkflowService"; +} + +pub(crate) struct FakeServer { + pub addr: std::net::SocketAddr, + shutdown_tx: oneshot::Sender<()>, + pub header_rx: tokio::sync::mpsc::UnboundedReceiver, + pub server_handle: tokio::task::JoinHandle<()>, +} + +pub(crate) async fn fake_server(response_maker: F) -> FakeServer +where + F: FnMut(tonic::codegen::http::Request) -> BoxFuture<'static, Response> + + Clone + + Send + + Sync + + 'static, +{ + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let (header_tx, header_rx) = tokio::sync::mpsc::unbounded_channel(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + Server::builder() + .add_service(GenericService { + header_to_parse: "grpc-timeout", + header_tx, + response_maker, + }) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + async { + shutdown_rx.await.ok(); + }, + ) + .await + .unwrap(); + }); + + FakeServer { + addr, + shutdown_tx, + header_rx, + server_handle, + } +} + +impl FakeServer { + pub(crate) async fn shutdown(self) { + self.shutdown_tx.send(()).unwrap(); + self.server_handle.await.unwrap(); + } +} diff --git a/crates/sdk-core/tests/common/mod.rs b/crates/sdk-core/tests/common/mod.rs index f87fe5588..5a1c400da 100644 --- a/crates/sdk-core/tests/common/mod.rs +++ b/crates/sdk-core/tests/common/mod.rs @@ -1,6 +1,7 @@ //! Common integration testing utilities //! These utilities are specific to integration tests and depend on the full temporal-client stack. +pub(crate) mod fake_grpc_server; pub(crate) mod http_proxy; pub(crate) mod workflows; @@ -240,11 +241,11 @@ struct InitializedWorker { impl CoreWfStarter { pub(crate) fn new(test_name: &str) -> Self { init_integ_telem(); - Self::_new(test_name, None, None) + Self::new_with_overrides(test_name, None, None) } pub(crate) fn new_with_runtime(test_name: &str, runtime: CoreRuntime) -> Self { - Self::_new(test_name, Some(runtime), None) + Self::new_with_overrides(test_name, Some(runtime), None) } /// Targets cloud if the required env vars are present. Otherwise, local server (but only if @@ -260,7 +261,7 @@ impl CoreWfStarter { check_mlsv = true; None }; - let mut s = Self::_new(test_name, None, client); + let mut s = Self::new_with_overrides(test_name, None, client); if check_mlsv && !version_req.is_empty() { let clustinfo = (*s.get_client().await) @@ -291,7 +292,7 @@ impl CoreWfStarter { Some(s) } - fn _new( + pub(crate) fn new_with_overrides( test_name: &str, runtime_override: Option, client_override: Option>, @@ -450,7 +451,7 @@ impl CoreWfStarter { let rt = if let Some(ref rto) = self.runtime_override { rto } else { - INTEG_TESTS_RT.get().unwrap() + init_integ_telem().unwrap() }; let cfg = self .worker_config diff --git a/crates/sdk-core/tests/integ_tests/client_tests.rs b/crates/sdk-core/tests/integ_tests/client_tests.rs index 0da32a024..ad080703d 100644 --- a/crates/sdk-core/tests/integ_tests/client_tests.rs +++ b/crates/sdk-core/tests/integ_tests/client_tests.rs @@ -1,17 +1,20 @@ -use crate::common::{CoreWfStarter, NAMESPACE, get_integ_server_options, http_proxy::HttpProxy}; +use crate::common::{ + CoreWfStarter, NAMESPACE, + fake_grpc_server::{GenericService, fake_server}, + get_integ_server_options, + http_proxy::HttpProxy, +}; use assert_matches::assert_matches; -use futures_util::{FutureExt, future::BoxFuture}; +use futures_util::FutureExt; use http_body_util::Full; use prost::Message; use std::{ collections::HashMap, - convert::Infallible, env, sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, - task::{Context, Poll}, time::Duration, }; use temporalio_client::{ @@ -27,16 +30,9 @@ use temporalio_common::protos::temporal::api::{ }; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::{ - net::TcpListener, - sync::{mpsc::UnboundedSender, oneshot}, -}; +use tokio::{net::TcpListener, sync::oneshot}; use tonic::{ - Code, IntoRequest, Request, Status, - body::Body, - codegen::{Service, http::Response}, - server::NamedService, - transport::Server, + Code, IntoRequest, Request, Status, body::Body, codegen::http::Response, transport::Server, }; use tracing::info; @@ -111,97 +107,6 @@ async fn per_call_timeout_respected_one_call() { ); } -#[derive(Clone)] -struct GenericService { - header_to_parse: &'static str, - header_tx: UnboundedSender, - response_maker: F, -} -impl Service> for GenericService -where - F: FnMut(tonic::codegen::http::Request) -> BoxFuture<'static, Response>, -{ - type Response = Response; - type Error = Infallible; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: tonic::codegen::http::Request) -> Self::Future { - self.header_tx - .send( - String::from_utf8_lossy( - req.headers() - .get(self.header_to_parse) - .map(|hv| hv.as_bytes()) - .unwrap_or_default(), - ) - .to_string(), - ) - .unwrap(); - let r = (self.response_maker)(req); - async move { Ok(r.await) }.boxed() - } -} -impl NamedService for GenericService { - const NAME: &'static str = "temporal.api.workflowservice.v1.WorkflowService"; -} - -struct FakeServer { - addr: std::net::SocketAddr, - shutdown_tx: oneshot::Sender<()>, - header_rx: tokio::sync::mpsc::UnboundedReceiver, - pub server_handle: tokio::task::JoinHandle<()>, -} - -async fn fake_server(response_maker: F) -> FakeServer -where - F: FnMut(tonic::codegen::http::Request) -> BoxFuture<'static, Response> - + Clone - + Send - + Sync - + 'static, -{ - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - let (header_tx, header_rx) = tokio::sync::mpsc::unbounded_channel(); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let server_handle = tokio::spawn(async move { - Server::builder() - .add_service(GenericService { - header_to_parse: "grpc-timeout", - header_tx, - response_maker, - }) - .serve_with_incoming_shutdown( - tokio_stream::wrappers::TcpListenerStream::new(listener), - async { - shutdown_rx.await.ok(); - }, - ) - .await - .unwrap(); - }); - - FakeServer { - addr, - shutdown_tx, - header_rx, - server_handle, - } -} - -impl FakeServer { - async fn shutdown(self) { - self.shutdown_tx.send(()).unwrap(); - self.server_handle.await.unwrap(); - } -} - #[tokio::test] async fn timeouts_respected_one_call_fake_server() { let mut fs = fake_server(|_| async { Response::new(Body::empty()) }.boxed()).await; diff --git a/crates/sdk-core/tests/integ_tests/worker_tests.rs b/crates/sdk-core/tests/integ_tests/worker_tests.rs index 39551d108..097b4e088 100644 --- a/crates/sdk-core/tests/integ_tests/worker_tests.rs +++ b/crates/sdk-core/tests/integ_tests/worker_tests.rs @@ -1,7 +1,7 @@ use crate::{ common::{ - CoreWfStarter, get_integ_runtime_options, get_integ_server_options, - get_integ_telem_options, mock_sdk_cfg, + CoreWfStarter, fake_grpc_server::fake_server, get_integ_runtime_options, + get_integ_server_options, get_integ_telem_options, mock_sdk_cfg, }, shared_tests, }; @@ -11,7 +11,10 @@ use std::{ cell::Cell, sync::{ Arc, Mutex, - atomic::{AtomicBool, Ordering::Relaxed}, + atomic::{ + AtomicBool, AtomicU8, + Ordering::{self, Relaxed}, + }, }, time::Duration, }; @@ -861,3 +864,31 @@ async fn test_custom_slot_supplier_simple() { "Number of reserves should equal number of releases" ); } + +#[tokio::test] +async fn shutdown_worker_not_retried() { + let shutdown_call_count = Arc::new(AtomicU8::new(0)); + let scc = shutdown_call_count.clone(); + let fs = fake_server(move |req| { + if req.uri().to_string().contains("ShutdownWorker") { + scc.fetch_add(1, Ordering::Relaxed); + } + let s = tonic::Status::new(tonic::Code::Unknown, "bla").into_http(); + async { s }.boxed() + }) + .await; + + let mut opts = get_integ_server_options(); + let uri = format!("http://localhost:{}", fs.addr.port()) + .parse() + .unwrap(); + opts.target_url = uri; + opts.skip_get_system_info = true; + let client = opts.connect("ns", None).await.unwrap(); + + let wf_type = "shutdown_worker_not_retried"; + let mut starter = CoreWfStarter::new_with_overrides(wf_type, None, Some(client)); + let worker = starter.get_worker().await; + drain_pollers_and_shutdown(&worker).await; + assert_eq!(shutdown_call_count.load(Ordering::Relaxed), 1); +} From 2955b7d5870aba082faad98cecbc9505c90e0581 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Fri, 7 Nov 2025 20:06:25 -0800 Subject: [PATCH 2/2] Fix possible shutdown hang --- crates/sdk-core/src/worker/heartbeat.rs | 18 +++++++++++------- crates/sdk-core/src/worker/mod.rs | 21 ++++++++++----------- crates/sdk-core/src/worker/workflow/mod.rs | 15 ++++++++------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/crates/sdk-core/src/worker/heartbeat.rs b/crates/sdk-core/src/worker/heartbeat.rs index 6beaeba3c..400c202c3 100644 --- a/crates/sdk-core/src/worker/heartbeat.rs +++ b/crates/sdk-core/src/worker/heartbeat.rs @@ -2,7 +2,7 @@ use crate::{ WorkerClient, worker::{TaskPollers, WorkerTelemetry}, }; -use parking_lot::Mutex; +use parking_lot::RwLock; use std::{collections::HashMap, sync::Arc, time::Duration}; use temporalio_client::SharedNamespaceWorkerTrait; use temporalio_common::{ @@ -20,7 +20,7 @@ pub(crate) type HeartbeatFn = Arc WorkerHeartbeat + Send + Sync>; /// worker heartbeats to the server. This invokes callbacks on all workers in the same process that /// share the same namespace. pub(crate) struct SharedNamespaceWorker { - heartbeat_map: Arc>>, + heartbeat_map: Arc>>, namespace: String, cancel: CancellationToken, } @@ -63,7 +63,7 @@ impl SharedNamespaceWorker { let client_clone = client; let namespace_clone = namespace.clone(); - let heartbeat_map = Arc::new(Mutex::new(HashMap::::new())); + let heartbeat_map = Arc::new(RwLock::new(HashMap::::new())); let heartbeat_map_clone = heartbeat_map.clone(); tokio::spawn(async move { @@ -93,7 +93,10 @@ impl SharedNamespaceWorker { tokio::select! { _ = ticker.tick() => { let mut hb_to_send = Vec::new(); - for (_instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() { + let hb_callbacks = { + heartbeat_map_clone.read().values().cloned().collect::>() + }; + for heartbeat_callback in hb_callbacks { let mut heartbeat = heartbeat_callback(); // All of these heartbeat details rely on a client. To avoid circular // dependencies, this must be populated from within SharedNamespaceWorker @@ -135,11 +138,12 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker { fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatFn) { self.heartbeat_map - .lock() + .write() .insert(worker_instance_key, heartbeat_callback); } + fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option, bool) { - let mut heartbeat_map = self.heartbeat_map.lock(); + let mut heartbeat_map = self.heartbeat_map.write(); let heartbeat_callback = heartbeat_map.remove(&worker_instance_key); if heartbeat_map.is_empty() { self.cancel.cancel(); @@ -148,7 +152,7 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker { } fn num_workers(&self) -> usize { - self.heartbeat_map.lock().len() + self.heartbeat_map.read().len() } } diff --git a/crates/sdk-core/src/worker/mod.rs b/crates/sdk-core/src/worker/mod.rs index 5721079b1..1ba8fb7ad 100644 --- a/crates/sdk-core/src/worker/mod.rs +++ b/crates/sdk-core/src/worker/mod.rs @@ -53,7 +53,7 @@ use anyhow::bail; use crossbeam_utils::atomic::AtomicCell; use futures_util::{StreamExt, stream}; use gethostname::gethostname; -use parking_lot::{Mutex, RwLock}; +use parking_lot::RwLock; use slot_provider::SlotProvider; use std::{ convert::TryInto, @@ -137,7 +137,7 @@ pub struct Worker { /// Used to track worker client client_worker_registrator: Arc, /// Status of the worker - status: Arc>, + status: Arc>, } struct AllPermitsTracker { @@ -256,7 +256,7 @@ impl WorkerTrait for Worker { } self.shutdown_token.cancel(); { - *self.status.lock() = WorkerStatus::ShuttingDown; + *self.status.write() = WorkerStatus::ShuttingDown; } // First, unregister worker from the client if !self.client_worker_registrator.shared_namespace_worker { @@ -276,12 +276,11 @@ impl WorkerTrait for Worker { if !self.workflows.ever_polled() { self.local_act_mgr.workflows_have_shutdown(); - } else { - // Bump the workflow stream with a pointless input, since if a client initiates shutdown - // and then immediately blocks waiting on a workflow activation poll, it's possible that - // there may not be any more inputs ever, and that poll will never resolve. - self.workflows.send_get_state_info_msg(); } + // Bump the workflow stream with a pointless input, since if a client initiates shutdown + // and then immediately blocks waiting on a workflow activation poll, it's possible that + // there may not be any more inputs ever, and that poll will never resolve. + self.workflows.send_get_state_info_msg(); } async fn shutdown(&self) { @@ -580,7 +579,7 @@ impl Worker { deployment_options, ); let worker_instance_key = Uuid::new_v4(); - let worker_status = Arc::new(Mutex::new(WorkerStatus::Running)); + let worker_status = Arc::new(RwLock::new(WorkerStatus::Running)); let sdk_name_and_ver = client.sdk_name_and_version(); let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| { @@ -1056,7 +1055,7 @@ struct HeartbeatMetrics { wf_sticky_last_suc_poll_time: Arc>>, act_last_suc_poll_time: Arc>>, nexus_last_suc_poll_time: Arc>>, - status: Arc>, + status: Arc>, sys_info: Arc, } @@ -1102,7 +1101,7 @@ impl WorkerHeartbeatManager { task_queue: config.task_queue.clone(), deployment_version, - status: (*heartbeat_manager_metrics.status.lock()) as i32, + status: (*heartbeat_manager_metrics.status.read()) as i32, start_time, plugins: config.plugins.clone(), diff --git a/crates/sdk-core/src/worker/workflow/mod.rs b/crates/sdk-core/src/worker/workflow/mod.rs index abddc6593..78bc22253 100644 --- a/crates/sdk-core/src/worker/workflow/mod.rs +++ b/crates/sdk-core/src/worker/workflow/mod.rs @@ -202,13 +202,6 @@ impl Workflows { .unwrap(); let local = LocalSet::new(); local.block_on(&rt, async move { - let mut stream = WFStream::build( - basics, - extracted_wft_stream, - locals_stream, - local_activity_request_sink, - ); - // However, we want to avoid plowing ahead until we've been asked to poll at // least once. This supports activity-only workers. let do_poll = tokio::select! { @@ -222,6 +215,14 @@ impl Workflows { if !do_poll { return; } + + let mut stream = WFStream::build( + basics, + extracted_wft_stream, + locals_stream, + local_activity_request_sink, + ); + while let Some(output) = stream.next().await { match output { Ok(o) => {