Skip to content

Commit 9a87ebf

Browse files
authored
No retry worker shutdown & fix shutdown hang (#1054)
1 parent 51c120c commit 9a87ebf

File tree

10 files changed

+210
-145
lines changed

10 files changed

+210
-145
lines changed

crates/client/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ pub struct NoRetryOnMatching {
329329
pub predicate: fn(&tonic::Status) -> bool,
330330
}
331331

332+
/// A request extension that forces overriding the current retry policy of the [RetryClient].
333+
#[derive(Clone, Debug)]
334+
pub struct RetryConfigForCall(pub RetryConfig);
335+
332336
impl Debug for ClientTlsConfig {
333337
// Intentionally omit details here since they could leak a key if ever printed
334338
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {

crates/client/src/retry.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::{
22
ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY,
3-
NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll,
3+
NamespacedClient, NoRetryOnMatching, Result, RetryConfig, RetryConfigForCall,
4+
raw::IsUserLongPoll,
45
};
56
use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
67
use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
@@ -60,6 +61,7 @@ impl<SG> RetryClient<SG> {
6061
) -> CallInfo {
6162
let mut call_type = CallType::Normal;
6263
let mut retry_short_circuit = None;
64+
let mut retry_cfg_override = None;
6365
if let Some(r) = request.as_ref() {
6466
let ext = r.extensions();
6567
if ext.get::<IsUserLongPoll>().is_some() {
@@ -69,8 +71,11 @@ impl<SG> RetryClient<SG> {
6971
}
7072

7173
retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
74+
retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
7275
}
73-
let retry_cfg = if call_type == CallType::TaskLongPoll {
76+
let retry_cfg = if let Some(ovr) = retry_cfg_override {
77+
ovr.0
78+
} else if call_type == CallType::TaskLongPoll {
7479
RetryConfig::task_poll_retry_policy()
7580
} else {
7681
(*self.retry_config).clone()

crates/sdk-core/src/worker/client.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::{
1111
};
1212
use temporalio_client::{
1313
Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching,
14-
RetryClient, SharedReplaceableClient, WorkflowService,
14+
RetryClient, RetryConfig, RetryConfigForCall, SharedReplaceableClient, WorkflowService,
1515
};
1616
use temporalio_common::{
1717
protos::{
@@ -697,16 +697,20 @@ impl WorkerClient for WorkerClientBag {
697697
w.status = WorkerStatus::Shutdown.into();
698698
self.set_heartbeat_client_fields(w);
699699
}
700-
let request = ShutdownWorkerRequest {
700+
let mut request = ShutdownWorkerRequest {
701701
namespace: self.namespace.clone(),
702702
identity: self.identity.clone(),
703703
sticky_task_queue,
704704
reason: "graceful shutdown".to_string(),
705705
worker_heartbeat: final_heartbeat,
706-
};
706+
}
707+
.into_request();
708+
request
709+
.extensions_mut()
710+
.insert(RetryConfigForCall(RetryConfig::no_retries()));
707711

708712
Ok(
709-
WorkflowService::shutdown_worker(&mut self.client.clone(), request.into_request())
713+
WorkflowService::shutdown_worker(&mut self.client.clone(), request)
710714
.await?
711715
.into_inner(),
712716
)

crates/sdk-core/src/worker/heartbeat.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
WorkerClient,
33
worker::{TaskPollers, WorkerTelemetry},
44
};
5-
use parking_lot::Mutex;
5+
use parking_lot::RwLock;
66
use std::{collections::HashMap, sync::Arc, time::Duration};
77
use temporalio_client::SharedNamespaceWorkerTrait;
88
use temporalio_common::{
@@ -20,7 +20,7 @@ pub(crate) type HeartbeatFn = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
2020
/// worker heartbeats to the server. This invokes callbacks on all workers in the same process that
2121
/// share the same namespace.
2222
pub(crate) struct SharedNamespaceWorker {
23-
heartbeat_map: Arc<Mutex<HashMap<Uuid, HeartbeatFn>>>,
23+
heartbeat_map: Arc<RwLock<HashMap<Uuid, HeartbeatFn>>>,
2424
namespace: String,
2525
cancel: CancellationToken,
2626
}
@@ -63,7 +63,7 @@ impl SharedNamespaceWorker {
6363
let client_clone = client;
6464
let namespace_clone = namespace.clone();
6565

66-
let heartbeat_map = Arc::new(Mutex::new(HashMap::<Uuid, HeartbeatFn>::new()));
66+
let heartbeat_map = Arc::new(RwLock::new(HashMap::<Uuid, HeartbeatFn>::new()));
6767
let heartbeat_map_clone = heartbeat_map.clone();
6868

6969
tokio::spawn(async move {
@@ -93,7 +93,10 @@ impl SharedNamespaceWorker {
9393
tokio::select! {
9494
_ = ticker.tick() => {
9595
let mut hb_to_send = Vec::new();
96-
for (_instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() {
96+
let hb_callbacks = {
97+
heartbeat_map_clone.read().values().cloned().collect::<Vec<_>>()
98+
};
99+
for heartbeat_callback in hb_callbacks {
97100
let mut heartbeat = heartbeat_callback();
98101
// All of these heartbeat details rely on a client. To avoid circular
99102
// dependencies, this must be populated from within SharedNamespaceWorker
@@ -135,11 +138,12 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker {
135138

136139
fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatFn) {
137140
self.heartbeat_map
138-
.lock()
141+
.write()
139142
.insert(worker_instance_key, heartbeat_callback);
140143
}
144+
141145
fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<HeartbeatFn>, bool) {
142-
let mut heartbeat_map = self.heartbeat_map.lock();
146+
let mut heartbeat_map = self.heartbeat_map.write();
143147
let heartbeat_callback = heartbeat_map.remove(&worker_instance_key);
144148
if heartbeat_map.is_empty() {
145149
self.cancel.cancel();
@@ -148,7 +152,7 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker {
148152
}
149153

150154
fn num_workers(&self) -> usize {
151-
self.heartbeat_map.lock().len()
155+
self.heartbeat_map.read().len()
152156
}
153157
}
154158

crates/sdk-core/src/worker/mod.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ use anyhow::bail;
5353
use crossbeam_utils::atomic::AtomicCell;
5454
use futures_util::{StreamExt, stream};
5555
use gethostname::gethostname;
56-
use parking_lot::{Mutex, RwLock};
56+
use parking_lot::RwLock;
5757
use slot_provider::SlotProvider;
5858
use std::{
5959
convert::TryInto,
@@ -137,7 +137,7 @@ pub struct Worker {
137137
/// Used to track worker client
138138
client_worker_registrator: Arc<ClientWorkerRegistrator>,
139139
/// Status of the worker
140-
status: Arc<Mutex<WorkerStatus>>,
140+
status: Arc<RwLock<WorkerStatus>>,
141141
}
142142

143143
struct AllPermitsTracker {
@@ -256,7 +256,7 @@ impl WorkerTrait for Worker {
256256
}
257257
self.shutdown_token.cancel();
258258
{
259-
*self.status.lock() = WorkerStatus::ShuttingDown;
259+
*self.status.write() = WorkerStatus::ShuttingDown;
260260
}
261261
// First, unregister worker from the client
262262
if !self.client_worker_registrator.shared_namespace_worker {
@@ -276,12 +276,11 @@ impl WorkerTrait for Worker {
276276

277277
if !self.workflows.ever_polled() {
278278
self.local_act_mgr.workflows_have_shutdown();
279-
} else {
280-
// Bump the workflow stream with a pointless input, since if a client initiates shutdown
281-
// and then immediately blocks waiting on a workflow activation poll, it's possible that
282-
// there may not be any more inputs ever, and that poll will never resolve.
283-
self.workflows.send_get_state_info_msg();
284279
}
280+
// Bump the workflow stream with a pointless input, since if a client initiates shutdown
281+
// and then immediately blocks waiting on a workflow activation poll, it's possible that
282+
// there may not be any more inputs ever, and that poll will never resolve.
283+
self.workflows.send_get_state_info_msg();
285284
}
286285

287286
async fn shutdown(&self) {
@@ -361,7 +360,12 @@ impl Worker {
361360

362361
#[cfg(test)]
363362
pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self {
364-
Self::new(config, None, Arc::new(client), None, None).unwrap()
363+
let sticky_queue_name = if config.max_cached_workflows > 0 {
364+
Some(format!("sticky-{}", config.task_queue))
365+
} else {
366+
None
367+
};
368+
Self::new(config, sticky_queue_name, Arc::new(client), None, None).unwrap()
365369
}
366370

367371
pub(crate) fn new_with_pollers(
@@ -575,7 +579,7 @@ impl Worker {
575579
deployment_options,
576580
);
577581
let worker_instance_key = Uuid::new_v4();
578-
let worker_status = Arc::new(Mutex::new(WorkerStatus::Running));
582+
let worker_status = Arc::new(RwLock::new(WorkerStatus::Running));
579583

580584
let sdk_name_and_ver = client.sdk_name_and_version();
581585
let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| {
@@ -698,7 +702,10 @@ impl Worker {
698702
tonic::Code::Unimplemented | tonic::Code::Unavailable
699703
) =>
700704
{
701-
warn!("Failed to shutdown sticky queue {:?}", err);
705+
warn!(
706+
"shutdown_worker rpc errored during worker shutdown: {:?}",
707+
err
708+
);
702709
}
703710
_ => {}
704711
}
@@ -1048,7 +1055,7 @@ struct HeartbeatMetrics {
10481055
wf_sticky_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
10491056
act_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
10501057
nexus_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
1051-
status: Arc<Mutex<WorkerStatus>>,
1058+
status: Arc<RwLock<WorkerStatus>>,
10521059
sys_info: Arc<dyn SystemResourceInfo + Send + Sync>,
10531060
}
10541061

@@ -1094,7 +1101,7 @@ impl WorkerHeartbeatManager {
10941101
task_queue: config.task_queue.clone(),
10951102
deployment_version,
10961103

1097-
status: (*heartbeat_manager_metrics.status.lock()) as i32,
1104+
status: (*heartbeat_manager_metrics.status.read()) as i32,
10981105
start_time,
10991106
plugins: config.plugins.clone(),
11001107

crates/sdk-core/src/worker/workflow/mod.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,6 @@ impl Workflows {
202202
.unwrap();
203203
let local = LocalSet::new();
204204
local.block_on(&rt, async move {
205-
let mut stream = WFStream::build(
206-
basics,
207-
extracted_wft_stream,
208-
locals_stream,
209-
local_activity_request_sink,
210-
);
211-
212205
// However, we want to avoid plowing ahead until we've been asked to poll at
213206
// least once. This supports activity-only workers.
214207
let do_poll = tokio::select! {
@@ -222,6 +215,14 @@ impl Workflows {
222215
if !do_poll {
223216
return;
224217
}
218+
219+
let mut stream = WFStream::build(
220+
basics,
221+
extracted_wft_stream,
222+
locals_stream,
223+
local_activity_request_sink,
224+
);
225+
225226
while let Some(output) = stream.next().await {
226227
match output {
227228
Ok(o) => {
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use futures_util::future::{BoxFuture, FutureExt};
2+
use std::{
3+
convert::Infallible,
4+
task::{Context, Poll},
5+
};
6+
use tokio::{
7+
net::TcpListener,
8+
sync::{mpsc::UnboundedSender, oneshot},
9+
};
10+
use tonic::{
11+
body::Body, codegen::Service, codegen::http::Response, server::NamedService, transport::Server,
12+
};
13+
14+
#[derive(Clone)]
15+
pub(crate) struct GenericService<F> {
16+
pub header_to_parse: &'static str,
17+
pub header_tx: UnboundedSender<String>,
18+
pub response_maker: F,
19+
}
20+
impl<F> Service<tonic::codegen::http::Request<Body>> for GenericService<F>
21+
where
22+
F: FnMut(tonic::codegen::http::Request<Body>) -> BoxFuture<'static, Response<Body>>,
23+
{
24+
type Response = Response<Body>;
25+
type Error = Infallible;
26+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
27+
28+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
29+
Poll::Ready(Ok(()))
30+
}
31+
32+
fn call(&mut self, req: tonic::codegen::http::Request<Body>) -> Self::Future {
33+
self.header_tx
34+
.send(
35+
String::from_utf8_lossy(
36+
req.headers()
37+
.get(self.header_to_parse)
38+
.map(|hv| hv.as_bytes())
39+
.unwrap_or_default(),
40+
)
41+
.to_string(),
42+
)
43+
.unwrap();
44+
let r = (self.response_maker)(req);
45+
async move { Ok(r.await) }.boxed()
46+
}
47+
}
48+
impl<F> NamedService for GenericService<F> {
49+
const NAME: &'static str = "temporal.api.workflowservice.v1.WorkflowService";
50+
}
51+
52+
pub(crate) struct FakeServer {
53+
pub addr: std::net::SocketAddr,
54+
shutdown_tx: oneshot::Sender<()>,
55+
pub header_rx: tokio::sync::mpsc::UnboundedReceiver<String>,
56+
pub server_handle: tokio::task::JoinHandle<()>,
57+
}
58+
59+
pub(crate) async fn fake_server<F>(response_maker: F) -> FakeServer
60+
where
61+
F: FnMut(tonic::codegen::http::Request<Body>) -> BoxFuture<'static, Response<Body>>
62+
+ Clone
63+
+ Send
64+
+ Sync
65+
+ 'static,
66+
{
67+
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
68+
let (header_tx, header_rx) = tokio::sync::mpsc::unbounded_channel();
69+
70+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
71+
let addr = listener.local_addr().unwrap();
72+
73+
let server_handle = tokio::spawn(async move {
74+
Server::builder()
75+
.add_service(GenericService {
76+
header_to_parse: "grpc-timeout",
77+
header_tx,
78+
response_maker,
79+
})
80+
.serve_with_incoming_shutdown(
81+
tokio_stream::wrappers::TcpListenerStream::new(listener),
82+
async {
83+
shutdown_rx.await.ok();
84+
},
85+
)
86+
.await
87+
.unwrap();
88+
});
89+
90+
FakeServer {
91+
addr,
92+
shutdown_tx,
93+
header_rx,
94+
server_handle,
95+
}
96+
}
97+
98+
impl FakeServer {
99+
pub(crate) async fn shutdown(self) {
100+
self.shutdown_tx.send(()).unwrap();
101+
self.server_handle.await.unwrap();
102+
}
103+
}

0 commit comments

Comments
 (0)