Skip to content

Commit

Permalink
Fix case when scheduler drops action on client reconnect (#1198)
Browse files Browse the repository at this point in the history
Fixes a bug where if a client creates an action, then the
client disconnects and then reconnects on the same action
it would not keep the action alive and eventually time it
out.

closes: #1197
  • Loading branch information
allada authored Jul 26, 2024
1 parent 534a102 commit 0b40639
Show file tree
Hide file tree
Showing 13 changed files with 310 additions and 127 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion nativelink-config/src/stores.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ pub struct EvictionPolicy {
#[serde(default, deserialize_with = "convert_data_size_with_shellexpand")]
pub evict_bytes: usize,

/// Maximum number of seconds for an entry to live before an eviction.
/// Maximum number of seconds for an entry to live since it was last
/// accessed before it is evicted.
/// Default: 0. Zero means never evict based on time.
#[serde(default, deserialize_with = "convert_duration_with_shellexpand")]
pub max_seconds: u32,
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ rust_test_suite(
"//nativelink-store",
"//nativelink-util",
"@crates//:futures",
"@crates//:mock_instant",
"@crates//:pretty_assertions",
"@crates//:prost",
"@crates//:tokio",
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ uuid = { version = "1.8.0", features = ["v4"] }
futures = "0.3.30"
hashbrown = "0.14"
lru = "0.12.3"
mock_instant = "0.3.2"
parking_lot = "0.12.2"
rand = "0.8.5"
scopeguard = "1.2.0"
Expand Down
146 changes: 92 additions & 54 deletions nativelink-scheduler/src/memory_awaited_action_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use std::time::Duration;

use async_lock::Mutex;
use async_trait::async_trait;
Expand All @@ -28,6 +28,7 @@ use nativelink_util::action_messages::{
};
use nativelink_util::chunked_stream::ChunkedStream;
use nativelink_util::evicting_map::{EvictingMap, LenEntry};
use nativelink_util::instant_wrapper::InstantWrapper;
use nativelink_util::metrics_utils::{CollectorState, MetricsComponent};
use nativelink_util::operation_state_manager::ActionStateResult;
use nativelink_util::spawn;
Expand All @@ -48,21 +49,21 @@ const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10);

/// Represents a client that is currently listening to an action.
/// When the client is dropped, it will send the [`AwaitedAction`] to the
/// `drop_tx` if there are other cleanups needed.
/// `event_tx` if there are other cleanups needed.
#[derive(Debug)]
struct ClientAwaitedAction {
/// The OperationId that the client is listening to.
operation_id: OperationId,

/// The sender to notify of this struct being dropped.
drop_tx: mpsc::UnboundedSender<ActionEvent>,
event_tx: mpsc::UnboundedSender<ActionEvent>,
}

impl ClientAwaitedAction {
pub fn new(operation_id: OperationId, drop_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
pub fn new(operation_id: OperationId, event_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
Self {
operation_id,
drop_tx,
event_tx,
}
}

Expand All @@ -74,7 +75,7 @@ impl ClientAwaitedAction {
impl Drop for ClientAwaitedAction {
fn drop(&mut self) {
// If we failed to send it means noone is listening.
let _ = self.drop_tx.send(ActionEvent::ClientDroppedOperation(
let _ = self.event_tx.send(ActionEvent::ClientDroppedOperation(
self.operation_id.clone(),
));
}
Expand Down Expand Up @@ -105,50 +106,61 @@ pub(crate) enum ActionEvent {

/// Information required to track an individual client
/// keep alive config and state.
struct ClientKeepAlive {
struct ClientInfo<I: InstantWrapper, NowFn: Fn() -> I> {
/// The client operation id.
client_operation_id: ClientOperationId,
/// The last time a keep alive was sent.
last_keep_alive: Instant,
/// The sender to notify of this struct being dropped.
drop_tx: mpsc::UnboundedSender<ActionEvent>,
last_keep_alive: I,
/// The function to get the current time.
now_fn: NowFn,
/// The sender to notify of this struct had an event.
event_tx: mpsc::UnboundedSender<ActionEvent>,
}

/// Subscriber that can be used to monitor when AwaitedActions change.
pub struct MemoryAwaitedActionSubscriber {
/// Subscriber that clients can be used to monitor when AwaitedActions change.
pub struct MemoryAwaitedActionSubscriber<I: InstantWrapper, NowFn: Fn() -> I> {
/// The receiver to listen for changes.
awaited_action_rx: watch::Receiver<AwaitedAction>,
/// The client operation id and keep alive information.
client_operation_info: Option<ClientKeepAlive>,
/// If a client id is known this is the info needed to keep the client
/// action alive.
client_info: Option<ClientInfo<I, NowFn>>,
}

impl MemoryAwaitedActionSubscriber {
impl<I: InstantWrapper, NowFn: Fn() -> I> MemoryAwaitedActionSubscriber<I, NowFn> {
pub fn new(mut awaited_action_rx: watch::Receiver<AwaitedAction>) -> Self {
awaited_action_rx.mark_changed();
Self {
awaited_action_rx,
client_operation_info: None,
client_info: None,
}
}

pub fn new_with_client(
mut awaited_action_rx: watch::Receiver<AwaitedAction>,
client_operation_id: ClientOperationId,
drop_tx: mpsc::UnboundedSender<ActionEvent>,
) -> Self {
event_tx: mpsc::UnboundedSender<ActionEvent>,
now_fn: NowFn,
) -> Self
where
NowFn: Fn() -> I,
{
awaited_action_rx.mark_changed();
Self {
awaited_action_rx,
client_operation_info: Some(ClientKeepAlive {
client_info: Some(ClientInfo {
client_operation_id,
last_keep_alive: Instant::now(),
drop_tx,
last_keep_alive: I::from_secs(0),
now_fn,
event_tx,
}),
}
}
}

impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber {
impl<I, NowFn> AwaitedActionSubscriber for MemoryAwaitedActionSubscriber<I, NowFn>
where
I: InstantWrapper,
NowFn: Fn() -> I + Send + Sync + 'static,
{
async fn changed(&mut self) -> Result<AwaitedAction, Error> {
{
let changed_fut = self.awaited_action_rx.changed().map(|r| {
Expand All @@ -159,25 +171,26 @@ impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber {
)
})
});
let Some(client_keep_alive) = self.client_operation_info.as_mut() else {
let Some(client_info) = self.client_info.as_mut() else {
changed_fut.await?;
return Ok(self.awaited_action_rx.borrow().clone());
};
tokio::pin!(changed_fut);
loop {
if client_keep_alive.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
client_keep_alive.last_keep_alive = Instant::now();
if client_info.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
client_info.last_keep_alive = (client_info.now_fn)();
// Failing to send just means our receiver dropped.
let _ = client_keep_alive.drop_tx.send(ActionEvent::ClientKeepAlive(
client_keep_alive.client_operation_id.clone(),
let _ = client_info.event_tx.send(ActionEvent::ClientKeepAlive(
client_info.client_operation_id.clone(),
));
}
let sleep_fut = (client_info.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
tokio::select! {
result = &mut changed_fut => {
result?;
break;
}
_ = tokio::time::sleep(CLIENT_KEEPALIVE_DURATION) => {
_ = sleep_fut => {
// If we haven't received any updates for a while, we should
// let the database know that we are still listening to prevent
// the action from being dropped.
Expand Down Expand Up @@ -329,10 +342,9 @@ impl SortedAwaitedActions {
}

/// The database for storing the state of all actions.
pub struct AwaitedActionDbImpl {
pub struct AwaitedActionDbImpl<I: InstantWrapper, NowFn: Fn() -> I> {
/// A lookup table to lookup the state of an action by its client operation id.
client_operation_to_awaited_action:
EvictingMap<ClientOperationId, Arc<ClientAwaitedAction>, SystemTime>,
client_operation_to_awaited_action: EvictingMap<ClientOperationId, Arc<ClientAwaitedAction>, I>,

/// A lookup table to lookup the state of an action by its worker operation id.
operation_id_to_awaited_action: BTreeMap<OperationId, watch::Sender<AwaitedAction>>,
Expand All @@ -351,13 +363,16 @@ pub struct AwaitedActionDbImpl {

/// Where to send notifications about important events related to actions.
action_event_tx: mpsc::UnboundedSender<ActionEvent>,

/// The function to get the current time.
now_fn: NowFn,
}

impl AwaitedActionDbImpl {
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbImpl<I, NowFn> {
async fn get_awaited_action_by_id(
&self,
client_operation_id: &ClientOperationId,
) -> Result<Option<MemoryAwaitedActionSubscriber>, Error> {
) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
let maybe_client_awaited_action = self
.client_operation_to_awaited_action
.get(client_operation_id)
Expand All @@ -369,7 +384,14 @@ impl AwaitedActionDbImpl {

self.operation_id_to_awaited_action
.get(client_awaited_action.operation_id())
.map(|tx| Some(MemoryAwaitedActionSubscriber::new(tx.subscribe())))
.map(|tx| {
Some(MemoryAwaitedActionSubscriber::new_with_client(
tx.subscribe(),
client_operation_id.clone(),
self.action_event_tx.clone(),
self.now_fn.clone(),
))
})
.ok_or_else(|| {
make_err!(
Code::Internal,
Expand Down Expand Up @@ -487,32 +509,38 @@ impl AwaitedActionDbImpl {
&self,
start: Bound<&OperationId>,
end: Bound<&OperationId>,
) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber)> {
) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber<I, NowFn>)> {
self.operation_id_to_awaited_action
.range((start, end))
.map(|(operation_id, tx)| {
(
operation_id,
MemoryAwaitedActionSubscriber::new(tx.subscribe()),
MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()),
)
})
}

fn get_by_operation_id(
&self,
operation_id: &OperationId,
) -> Option<MemoryAwaitedActionSubscriber> {
) -> Option<MemoryAwaitedActionSubscriber<I, NowFn>> {
self.operation_id_to_awaited_action
.get(operation_id)
.map(|tx| MemoryAwaitedActionSubscriber::new(tx.subscribe()))
.map(|tx| MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()))
}

fn get_range_of_actions<'a, 'b>(
&'a self,
state: SortedAwaitedActionState,
range: impl RangeBounds<SortedAwaitedAction> + 'b,
) -> impl DoubleEndedIterator<
Item = Result<(&'a SortedAwaitedAction, MemoryAwaitedActionSubscriber), Error>,
Item = Result<
(
&'a SortedAwaitedAction,
MemoryAwaitedActionSubscriber<I, NowFn>,
),
Error,
>,
> + 'a {
let btree = match state {
SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
Expand Down Expand Up @@ -674,7 +702,7 @@ impl AwaitedActionDbImpl {
&mut self,
client_operation_id: ClientOperationId,
action_info: Arc<ActionInfo>,
) -> Result<MemoryAwaitedActionSubscriber, Error> {
) -> Result<MemoryAwaitedActionSubscriber<I, NowFn>, Error> {
// Check to see if the action is already known and subscribe if it is.
let subscription_result = self
.try_subscribe(
Expand Down Expand Up @@ -738,6 +766,7 @@ impl AwaitedActionDbImpl {
rx,
client_operation_id,
self.action_event_tx.clone(),
self.now_fn.clone(),
))
}

Expand All @@ -749,7 +778,7 @@ impl AwaitedActionDbImpl {
// removed the ability to upgrade priorities of actions.
// we should add priority upgrades back in.
_priority: i32,
) -> Result<Option<MemoryAwaitedActionSubscriber>, Error> {
) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
let unique_key = match unique_qualifier {
ActionUniqueQualifier::Cachable(unique_key) => unique_key,
ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None),
Expand Down Expand Up @@ -795,28 +824,33 @@ impl AwaitedActionDbImpl {
)
.await;

Ok(Some(MemoryAwaitedActionSubscriber::new(subscription)))
Ok(Some(MemoryAwaitedActionSubscriber::new_with_client(
subscription,
client_operation_id.clone(),
self.action_event_tx.clone(),
self.now_fn.clone(),
)))
}
}

pub struct MemoryAwaitedActionDb {
inner: Arc<Mutex<AwaitedActionDbImpl>>,
pub struct MemoryAwaitedActionDb<I: InstantWrapper, NowFn: Fn() -> I> {
inner: Arc<Mutex<AwaitedActionDbImpl<I, NowFn>>>,
_handle_awaited_action_events: JoinHandleDropGuard<()>,
}

impl MemoryAwaitedActionDb {
pub fn new(eviction_config: &EvictionPolicy) -> Self {
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static>
MemoryAwaitedActionDb<I, NowFn>
{
pub fn new(eviction_config: &EvictionPolicy, now_fn: NowFn) -> Self {
let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel();
let inner = Arc::new(Mutex::new(AwaitedActionDbImpl {
client_operation_to_awaited_action: EvictingMap::new(
eviction_config,
SystemTime::now(),
),
client_operation_to_awaited_action: EvictingMap::new(eviction_config, (now_fn)()),
operation_id_to_awaited_action: BTreeMap::new(),
action_info_hash_key_to_awaited_action: HashMap::new(),
sorted_action_info_hash_keys: SortedAwaitedActions::default(),
connected_clients_for_operation_id: HashMap::new(),
action_event_tx,
now_fn,
}));
let weak_inner = Arc::downgrade(&inner);
Self {
Expand All @@ -841,8 +875,10 @@ impl MemoryAwaitedActionDb {
}
}

impl AwaitedActionDb for MemoryAwaitedActionDb {
type Subscriber = MemoryAwaitedActionSubscriber;
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> AwaitedActionDb
for MemoryAwaitedActionDb<I, NowFn>
{
type Subscriber = MemoryAwaitedActionSubscriber<I, NowFn>;

async fn get_awaited_action_by_id(
&self,
Expand Down Expand Up @@ -943,7 +979,9 @@ impl AwaitedActionDb for MemoryAwaitedActionDb {
}
}

impl MetricsComponent for MemoryAwaitedActionDb {
impl<I: InstantWrapper, NowFn: Fn() -> I + Send + Sync + 'static> MetricsComponent
for MemoryAwaitedActionDb<I, NowFn>
{
fn gather_metrics(&self, c: &mut CollectorState) {
let inner = self.inner.lock_blocking();
c.publish(
Expand Down
Loading

0 comments on commit 0b40639

Please sign in to comment.