Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ pub struct NoRetryOnMatching {
pub predicate: fn(&tonic::Status) -> bool,
}

/// A request extension that forces overriding the current retry policy of the [RetryClient].
#[derive(Clone, Debug)]
pub struct RetryConfigForCall(pub RetryConfig);

impl Debug for ClientTlsConfig {
// Intentionally omit details here since they could leak a key if ever printed
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Expand Down
9 changes: 7 additions & 2 deletions crates/client/src/retry.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY,
NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll,
NamespacedClient, NoRetryOnMatching, Result, RetryConfig, RetryConfigForCall,
raw::IsUserLongPoll,
};
use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
Expand Down Expand Up @@ -60,6 +61,7 @@ impl<SG> RetryClient<SG> {
) -> CallInfo {
let mut call_type = CallType::Normal;
let mut retry_short_circuit = None;
let mut retry_cfg_override = None;
if let Some(r) = request.as_ref() {
let ext = r.extensions();
if ext.get::<IsUserLongPoll>().is_some() {
Expand All @@ -69,8 +71,11 @@ impl<SG> RetryClient<SG> {
}

retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
}
let retry_cfg = if call_type == CallType::TaskLongPoll {
let retry_cfg = if let Some(ovr) = retry_cfg_override {
ovr.0
} else if call_type == CallType::TaskLongPoll {
RetryConfig::task_poll_retry_policy()
} else {
(*self.retry_config).clone()
Expand Down
12 changes: 8 additions & 4 deletions crates/sdk-core/src/worker/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
};
use temporalio_client::{
Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching,
RetryClient, SharedReplaceableClient, WorkflowService,
RetryClient, RetryConfig, RetryConfigForCall, SharedReplaceableClient, WorkflowService,
};
use temporalio_common::{
protos::{
Expand Down Expand Up @@ -697,16 +697,20 @@ impl WorkerClient for WorkerClientBag {
w.status = WorkerStatus::Shutdown.into();
self.set_heartbeat_client_fields(w);
}
let request = ShutdownWorkerRequest {
let mut request = ShutdownWorkerRequest {
namespace: self.namespace.clone(),
identity: self.identity.clone(),
sticky_task_queue,
reason: "graceful shutdown".to_string(),
worker_heartbeat: final_heartbeat,
};
}
.into_request();
request
.extensions_mut()
.insert(RetryConfigForCall(RetryConfig::no_retries()));

Ok(
WorkflowService::shutdown_worker(&mut self.client.clone(), request.into_request())
WorkflowService::shutdown_worker(&mut self.client.clone(), request)
.await?
.into_inner(),
)
Expand Down
18 changes: 11 additions & 7 deletions crates/sdk-core/src/worker/heartbeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
WorkerClient,
worker::{TaskPollers, WorkerTelemetry},
};
use parking_lot::Mutex;
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc, time::Duration};
use temporalio_client::SharedNamespaceWorkerTrait;
use temporalio_common::{
Expand All @@ -20,7 +20,7 @@ pub(crate) type HeartbeatFn = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
/// worker heartbeats to the server. This invokes callbacks on all workers in the same process that
/// share the same namespace.
pub(crate) struct SharedNamespaceWorker {
heartbeat_map: Arc<Mutex<HashMap<Uuid, HeartbeatFn>>>,
heartbeat_map: Arc<RwLock<HashMap<Uuid, HeartbeatFn>>>,
namespace: String,
cancel: CancellationToken,
}
Expand Down Expand Up @@ -63,7 +63,7 @@ impl SharedNamespaceWorker {
let client_clone = client;
let namespace_clone = namespace.clone();

let heartbeat_map = Arc::new(Mutex::new(HashMap::<Uuid, HeartbeatFn>::new()));
let heartbeat_map = Arc::new(RwLock::new(HashMap::<Uuid, HeartbeatFn>::new()));
let heartbeat_map_clone = heartbeat_map.clone();

tokio::spawn(async move {
Expand Down Expand Up @@ -93,7 +93,10 @@ impl SharedNamespaceWorker {
tokio::select! {
_ = ticker.tick() => {
let mut hb_to_send = Vec::new();
for (_instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() {
let hb_callbacks = {
heartbeat_map_clone.read().values().cloned().collect::<Vec<_>>()
};
for heartbeat_callback in hb_callbacks {
Comment on lines +96 to +99
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just some bonus cleanup to avoid holding this lock while doing the callback gathering

let mut heartbeat = heartbeat_callback();
// All of these heartbeat details rely on a client. To avoid circular
// dependencies, this must be populated from within SharedNamespaceWorker
Expand Down Expand Up @@ -135,11 +138,12 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker {

fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatFn) {
self.heartbeat_map
.lock()
.write()
.insert(worker_instance_key, heartbeat_callback);
}

fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<HeartbeatFn>, bool) {
let mut heartbeat_map = self.heartbeat_map.lock();
let mut heartbeat_map = self.heartbeat_map.write();
let heartbeat_callback = heartbeat_map.remove(&worker_instance_key);
if heartbeat_map.is_empty() {
self.cancel.cancel();
Expand All @@ -148,7 +152,7 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker {
}

fn num_workers(&self) -> usize {
self.heartbeat_map.lock().len()
self.heartbeat_map.read().len()
}
}

Expand Down
33 changes: 20 additions & 13 deletions crates/sdk-core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use anyhow::bail;
use crossbeam_utils::atomic::AtomicCell;
use futures_util::{StreamExt, stream};
use gethostname::gethostname;
use parking_lot::{Mutex, RwLock};
use parking_lot::RwLock;
use slot_provider::SlotProvider;
use std::{
convert::TryInto,
Expand Down Expand Up @@ -137,7 +137,7 @@ pub struct Worker {
/// Used to track worker client
client_worker_registrator: Arc<ClientWorkerRegistrator>,
/// Status of the worker
status: Arc<Mutex<WorkerStatus>>,
status: Arc<RwLock<WorkerStatus>>,
}

struct AllPermitsTracker {
Expand Down Expand Up @@ -256,7 +256,7 @@ impl WorkerTrait for Worker {
}
self.shutdown_token.cancel();
{
*self.status.lock() = WorkerStatus::ShuttingDown;
*self.status.write() = WorkerStatus::ShuttingDown;
}
// First, unregister worker from the client
if !self.client_worker_registrator.shared_namespace_worker {
Expand All @@ -276,12 +276,11 @@ impl WorkerTrait for Worker {

if !self.workflows.ever_polled() {
self.local_act_mgr.workflows_have_shutdown();
} else {
// Bump the workflow stream with a pointless input, since if a client initiates shutdown
// and then immediately blocks waiting on a workflow activation poll, it's possible that
// there may not be any more inputs ever, and that poll will never resolve.
self.workflows.send_get_state_info_msg();
}
// Bump the workflow stream with a pointless input, since if a client initiates shutdown
// and then immediately blocks waiting on a workflow activation poll, it's possible that
// there may not be any more inputs ever, and that poll will never resolve.
self.workflows.send_get_state_info_msg();
Comment on lines +280 to +283
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the real fix for the shutdown hang. There was a race where "ever polled" could be set true, but nothing else ever ended up in the stream, and driving the stream with bonus messages doesn't happen until we get to shutdown itself, but that won't happen (at least in some test setups, but maybe some actual SDK loops too) until pollers have returned.

}

async fn shutdown(&self) {
Expand Down Expand Up @@ -361,7 +360,12 @@ impl Worker {

#[cfg(test)]
pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self {
Self::new(config, None, Arc::new(client), None, None).unwrap()
let sticky_queue_name = if config.max_cached_workflows > 0 {
Some(format!("sticky-{}", config.task_queue))
} else {
None
};
Self::new(config, sticky_queue_name, Arc::new(client), None, None).unwrap()
}

pub(crate) fn new_with_pollers(
Expand Down Expand Up @@ -575,7 +579,7 @@ impl Worker {
deployment_options,
);
let worker_instance_key = Uuid::new_v4();
let worker_status = Arc::new(Mutex::new(WorkerStatus::Running));
let worker_status = Arc::new(RwLock::new(WorkerStatus::Running));

let sdk_name_and_ver = client.sdk_name_and_version();
let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| {
Expand Down Expand Up @@ -698,7 +702,10 @@ impl Worker {
tonic::Code::Unimplemented | tonic::Code::Unavailable
) =>
{
warn!("Failed to shutdown sticky queue {:?}", err);
warn!(
"shutdown_worker rpc errored during worker shutdown: {:?}",
err
);
}
_ => {}
}
Expand Down Expand Up @@ -1048,7 +1055,7 @@ struct HeartbeatMetrics {
wf_sticky_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
act_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
nexus_last_suc_poll_time: Arc<AtomicCell<Option<SystemTime>>>,
status: Arc<Mutex<WorkerStatus>>,
status: Arc<RwLock<WorkerStatus>>,
sys_info: Arc<dyn SystemResourceInfo + Send + Sync>,
}

Expand Down Expand Up @@ -1094,7 +1101,7 @@ impl WorkerHeartbeatManager {
task_queue: config.task_queue.clone(),
deployment_version,

status: (*heartbeat_manager_metrics.status.lock()) as i32,
status: (*heartbeat_manager_metrics.status.read()) as i32,
start_time,
plugins: config.plugins.clone(),

Expand Down
15 changes: 8 additions & 7 deletions crates/sdk-core/src/worker/workflow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,6 @@ impl Workflows {
.unwrap();
let local = LocalSet::new();
local.block_on(&rt, async move {
let mut stream = WFStream::build(
basics,
extracted_wft_stream,
locals_stream,
local_activity_request_sink,
);

// However, we want to avoid plowing ahead until we've been asked to poll at
// least once. This supports activity-only workers.
let do_poll = tokio::select! {
Expand All @@ -222,6 +215,14 @@ impl Workflows {
if !do_poll {
return;
}

let mut stream = WFStream::build(
basics,
extracted_wft_stream,
locals_stream,
local_activity_request_sink,
);

while let Some(output) = stream.next().await {
match output {
Ok(o) => {
Expand Down
103 changes: 103 additions & 0 deletions crates/sdk-core/tests/common/fake_grpc_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use futures_util::future::{BoxFuture, FutureExt};
use std::{
convert::Infallible,
task::{Context, Poll},
};
use tokio::{
net::TcpListener,
sync::{mpsc::UnboundedSender, oneshot},
};
use tonic::{
body::Body, codegen::Service, codegen::http::Response, server::NamedService, transport::Server,
};

#[derive(Clone)]
pub(crate) struct GenericService<F> {
pub header_to_parse: &'static str,
pub header_tx: UnboundedSender<String>,
pub response_maker: F,
}
impl<F> Service<tonic::codegen::http::Request<Body>> for GenericService<F>
where
F: FnMut(tonic::codegen::http::Request<Body>) -> BoxFuture<'static, Response<Body>>,
{
type Response = Response<Body>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, req: tonic::codegen::http::Request<Body>) -> Self::Future {
self.header_tx
.send(
String::from_utf8_lossy(
req.headers()
.get(self.header_to_parse)
.map(|hv| hv.as_bytes())
.unwrap_or_default(),
)
.to_string(),
)
.unwrap();
let r = (self.response_maker)(req);
async move { Ok(r.await) }.boxed()
}
}
impl<F> NamedService for GenericService<F> {
const NAME: &'static str = "temporal.api.workflowservice.v1.WorkflowService";
}

pub(crate) struct FakeServer {
pub addr: std::net::SocketAddr,
shutdown_tx: oneshot::Sender<()>,
pub header_rx: tokio::sync::mpsc::UnboundedReceiver<String>,
pub server_handle: tokio::task::JoinHandle<()>,
}

pub(crate) async fn fake_server<F>(response_maker: F) -> FakeServer
where
F: FnMut(tonic::codegen::http::Request<Body>) -> BoxFuture<'static, Response<Body>>
+ Clone
+ Send
+ Sync
+ 'static,
{
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let (header_tx, header_rx) = tokio::sync::mpsc::unbounded_channel();

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let server_handle = tokio::spawn(async move {
Server::builder()
.add_service(GenericService {
header_to_parse: "grpc-timeout",
header_tx,
response_maker,
})
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
async {
shutdown_rx.await.ok();
},
)
.await
.unwrap();
});

FakeServer {
addr,
shutdown_tx,
header_rx,
server_handle,
}
}

impl FakeServer {
pub(crate) async fn shutdown(self) {
self.shutdown_tx.send(()).unwrap();
self.server_handle.await.unwrap();
}
}
Loading
Loading