[Refactor] Simple scheduler method signatures to async (#971)
Holding a lock across async code requires `async_lock` rather than
`parking_lot`, so `Mutex` and `MutexGuard` are migrated to that crate.
Function signatures are also updated to `async` to better fit the API
updates for the new scheduler.
adam-singer authored Jun 6, 2024
1 parent cec25fb commit 3c50dd5
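
The change hinges on a difference between the two mutex types: a `parking_lot::MutexGuard` is `!Send` by default, so it cannot be held across an `.await` point in a future that must be `Send`, while `async_lock::Mutex::lock()` returns a future and its guard can be held across awaits. A minimal, self-contained sketch of the new pattern (illustration only, not code from this commit; `Inner` and `queue_and_notify` are made-up stand-ins):

```rust
use async_lock::Mutex;

// `Inner` stands in for the scheduler state guarded by the mutex.
struct Inner {
    queued_actions: u64,
}

async fn queue_and_notify(inner: &Mutex<Inner>) {
    // `lock()` returns a future: waiting for the mutex yields to the
    // executor instead of blocking the OS thread.
    let mut guard = inner.lock().await;
    guard.queued_actions += 1;
    // The guard can be held across further `.await` points, which a
    // `parking_lot::MutexGuard` (`!Send` by default) does not allow in a
    // future that must be `Send`.
    tokio::task::yield_now().await;
}

#[tokio::main]
async fn main() {
    let inner = Mutex::new(Inner { queued_actions: 0 });
    queue_and_notify(&inner).await;
    println!("queued_actions = {}", inner.lock().await.queued_actions);
}
```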
Showing 2 changed files with 56 additions and 37 deletions.
71 changes: 41 additions & 30 deletions nativelink-scheduler/src/simple_scheduler.rs
@@ -20,6 +20,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Instant, SystemTime};

use async_lock::{Mutex, MutexGuard};
use async_trait::async_trait;
use futures::Future;
use hashbrown::{HashMap, HashSet};
@@ -37,7 +38,6 @@ use nativelink_util::metrics_utils::{
use nativelink_util::platform_properties::PlatformPropertyValue;
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use parking_lot::{Mutex, MutexGuard};
use tokio::sync::{watch, Notify};
use tokio::time::Duration;
use tracing::{event, Level};
@@ -257,7 +257,7 @@ impl SimpleSchedulerImpl {
/// If the task cannot be executed immediately it will be queued for execution
/// based on priority and other metrics.
/// All further updates to the action will be provided through `listener`.
fn add_action(
async fn add_action(
&mut self,
action_info: ActionInfo,
) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
@@ -341,7 +341,7 @@ impl SimpleSchedulerImpl {
.map(|action| watch::channel(action.state.clone()).1)
}

fn find_existing_action(
async fn find_existing_action(
&self,
unique_qualifier: &ActionInfoHashKey,
) -> Option<watch::Receiver<Arc<ActionState>>> {
@@ -444,7 +444,7 @@ impl SimpleSchedulerImpl {
// TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map
// of capabilities of each worker and then try and match the actions to the worker using
// the map lookup (ie. map reduce).
fn do_try_match(&mut self) {
async fn do_try_match(&mut self) {
// TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in
// the way. We need to conditionally remove items from the `queued_action`. Rust is working
// to add `drain_filter`, which would in theory solve this problem, but because we need
@@ -598,7 +598,7 @@ impl SimpleSchedulerImpl {
self.tasks_or_workers_change_notify.notify_one();
}

fn update_action(
async fn update_action(
&mut self,
worker_id: &WorkerId,
action_info_hash_key: &ActionInfoHashKey,
@@ -781,9 +781,9 @@ impl SimpleScheduler {
// really need to worry about this thread taking the lock
// starving other threads too much.
Some(inner_mux) => {
let mut inner = inner_mux.lock();
let mut inner = inner_mux.lock().await;
let timer = metrics_for_do_try_match.do_try_match.begin_timer();
inner.do_try_match();
inner.do_try_match().await;
timer.measure();
}
// If the inner went away it means the scheduler is shutting
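
For context, the `inner_mux.lock().await` and `do_try_match().await` calls above run inside the scheduler's background matching task, which holds only a weak reference to the shared state (hence the `Some(inner_mux)` arm) and, judging from the surrounding code, wakes on the `tasks_or_workers_change_notify` notification seen in an earlier hunk. A schematic of a loop like this, not the file's actual code, with stand-in types and an `Arc` where the real task upgrades a `Weak`:

```rust
use std::sync::Arc;
use async_lock::Mutex;
use tokio::sync::Notify;

// Stand-in for `SimpleSchedulerImpl`; illustration only.
struct Inner;

impl Inner {
    async fn do_try_match(&mut self) {
        // Match queued actions against available workers here.
    }
}

// Schematic background matcher: wait for a change notification, take the
// async lock, run one matching pass, then release the lock and repeat.
fn spawn_matcher(inner: Arc<Mutex<Inner>>, notify: Arc<Notify>) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            notify.notified().await;
            let mut guard = inner.lock().await;
            guard.do_try_match().await;
        }
    })
}
```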
@@ -801,34 +801,40 @@

/// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
#[must_use]
pub fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
let inner = self.get_inner_lock();
pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
let inner = self.get_inner_lock().await;
inner.workers.workers.contains(worker_id)
}

/// Checks to see if the worker can accept work. Should only be used in unit tests.
pub fn can_worker_accept_work_for_test(&self, worker_id: &WorkerId) -> Result<bool, Error> {
let mut inner = self.get_inner_lock();
pub async fn can_worker_accept_work_for_test(
&self,
worker_id: &WorkerId,
) -> Result<bool, Error> {
let mut inner = self.get_inner_lock().await;
let worker = inner.workers.workers.get_mut(worker_id).ok_or_else(|| {
make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
})?;
Ok(worker.can_accept_work())
}

/// A unit test function used to send the keep alive message to the worker from the server.
pub fn send_keep_alive_to_worker_for_test(&self, worker_id: &WorkerId) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
pub async fn send_keep_alive_to_worker_for_test(
&self,
worker_id: &WorkerId,
) -> Result<(), Error> {
let mut inner = self.get_inner_lock().await;
let worker = inner.workers.workers.get_mut(worker_id).ok_or_else(|| {
make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
})?;
worker.keep_alive()
}

fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> {
async fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> {
// We don't use one of the wrappers because we only want to capture the time spent,
// nothing else, because this is a hot path.
let start = Instant::now();
let lock = self.inner.lock();
let lock: MutexGuard<SimpleSchedulerImpl> = self.inner.lock().await;
self.metrics
.lock_stall_time
.fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed);
@@ -852,19 +858,21 @@ impl ActionScheduler for SimpleScheduler {
&self,
action_info: ActionInfo,
) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
self.metrics
.add_action
.wrap(move || inner.add_action(action_info))
.wrap(inner.add_action(action_info))
.await
}

async fn find_existing_action(
&self,
unique_qualifier: &ActionInfoHashKey,
) -> Option<watch::Receiver<Arc<ActionState>>> {
let inner = self.get_inner_lock();
let inner = self.get_inner_lock().await;
let result = inner
.find_existing_action(unique_qualifier)
.await
.or_else(|| inner.find_recently_completed_action(unique_qualifier));
if result.is_some() {
self.metrics.existing_actions_found.inc();
@@ -875,7 +883,9 @@ impl ActionScheduler for SimpleScheduler {
}

async fn clean_recently_completed_actions(&self) {
self.get_inner_lock().clean_recently_completed_actions();
self.get_inner_lock()
.await
.clean_recently_completed_actions();
self.metrics.clean_recently_completed_actions.inc()
}

Expand All @@ -892,7 +902,7 @@ impl WorkerScheduler for SimpleScheduler {

async fn add_worker(&self, worker: Worker) -> Result<(), Error> {
let worker_id = worker.id;
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
self.metrics.add_worker.wrap(move || {
let res = inner
.workers
@@ -912,7 +922,7 @@
action_info_hash_key: &ActionInfoHashKey,
err: Error,
) {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
inner.update_action_with_internal_error(worker_id, action_info_hash_key, err);
}

Expand All @@ -922,34 +932,35 @@ impl WorkerScheduler for SimpleScheduler {
action_info_hash_key: &ActionInfoHashKey,
action_stage: ActionStage,
) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
self.metrics
.update_action
.wrap(move || inner.update_action(worker_id, action_info_hash_key, action_stage))
.wrap(inner.update_action(worker_id, action_info_hash_key, action_stage))
.await
}

async fn worker_keep_alive_received(
&self,
worker_id: &WorkerId,
timestamp: WorkerTimestamp,
) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
inner
.workers
.refresh_lifetime(worker_id, timestamp)
.err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()")
}

async fn remove_worker(&self, worker_id: WorkerId) {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
inner.immediate_evict_worker(
&worker_id,
make_err!(Code::Internal, "Received request to remove worker"),
);
}

async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
self.metrics.remove_timedout_workers.wrap(move || {
// Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire
// map most of the time.
@@ -986,7 +997,7 @@ impl WorkerScheduler for SimpleScheduler {
}

async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> {
let mut inner = self.get_inner_lock();
let mut inner = self.get_inner_lock().await;
inner.set_drain_worker(worker_id, is_draining)
}

Expand All @@ -1001,7 +1012,7 @@ impl MetricsComponent for SimpleScheduler {
self.metrics.gather_metrics(c);
{
// We use the raw lock because we don't gather stats about gathering stats.
let inner = self.inner.lock();
let inner = self.inner.lock_blocking();
c.publish(
"queued_actions_total",
&inner.queued_actions.len(),
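
One detail from the hunk above: `gather_metrics` runs in a synchronous context, so it cannot `.await` the lock; instead the new code takes the same `async_lock::Mutex` with `lock_blocking()`. A tiny sketch of that blocking form (the `queued_total` helper is hypothetical):

```rust
use async_lock::Mutex;

// Hypothetical helper: a synchronous caller takes the same async mutex
// with its blocking form rather than `.await`ing it.
fn queued_total(inner: &Mutex<Vec<u64>>) -> usize {
    inner.lock_blocking().len()
}

fn main() {
    let inner = Mutex::new(vec![1, 2, 3]);
    println!("queued_actions_total = {}", queued_total(&inner));
}
```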
@@ -1127,12 +1138,12 @@ impl MetricsComponent for AwaitedAction {

#[derive(Default)]
struct Metrics {
add_action: FuncCounterWrapper,
add_action: AsyncCounterWrapper,
existing_actions_found: CounterWithTime,
existing_actions_not_found: CounterWithTime,
clean_recently_completed_actions: CounterWithTime,
remove_timedout_workers: FuncCounterWrapper,
update_action: FuncCounterWrapper,
update_action: AsyncCounterWrapper,
update_action_missing_action_result: CounterWithTime,
update_action_from_wrong_worker: CounterWithTime,
update_action_no_more_listeners: CounterWithTime,
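The final hunk above switches the `add_action` and `update_action` counters from `FuncCounterWrapper` to `AsyncCounterWrapper`, matching the call sites earlier in the diff that changed from `.wrap(move || ...)` to `.wrap(future).await`. As an illustration of that shape only (not the actual implementation in `nativelink_util::metrics_utils`), an async counter wrapper can be sketched as:

```rust
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};

// Illustration only; the real `AsyncCounterWrapper` may track more than a call count.
#[derive(Default)]
struct AsyncCounter {
    calls: AtomicU64,
}

impl AsyncCounter {
    // Takes a future rather than a closure, so the counted work can itself await.
    async fn wrap<T>(&self, fut: impl Future<Output = T>) -> T {
        self.calls.fetch_add(1, Ordering::Relaxed);
        fut.await
    }
}

#[tokio::main]
async fn main() {
    let counter = AsyncCounter::default();
    let value = counter.wrap(async { 42 }).await;
    println!("value = {value}, calls = {}", counter.calls.load(Ordering::Relaxed));
}
```
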
22 changes: 15 additions & 7 deletions nativelink-service/tests/worker_api_server_test.rs
@@ -124,7 +124,8 @@ pub mod connect_worker_tests {

let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(worker_exists, "Expected worker to exist in worker map");

Ok(())
@@ -151,7 +152,8 @@ pub mod keep_alive_tests {
.await?;
let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(worker_exists, "Expected worker to exist in worker map");
}
{
@@ -163,7 +165,8 @@
.await?;
let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(!worker_exists, "Expected worker to not exist in map");
}

@@ -195,7 +198,8 @@ pub mod keep_alive_tests {
.await?;
let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(worker_exists, "Expected worker to exist in worker map");
}
{
@@ -217,7 +221,8 @@
.await?;
let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(worker_exists, "Expected worker to exist in map");
}

Expand All @@ -234,6 +239,7 @@ pub mod keep_alive_tests {
test_context
.scheduler
.send_keep_alive_to_worker_for_test(&test_context.worker_id)
.await
.err_tip(|| "Could not send keep alive to worker")?;

{
@@ -269,7 +275,8 @@ pub mod going_away_tests {

let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(worker_exists, "Expected worker to exist in worker map");

test_context
Expand All @@ -279,7 +286,8 @@ pub mod going_away_tests {

let worker_exists = test_context
.scheduler
.contains_worker_for_test(&test_context.worker_id);
.contains_worker_for_test(&test_context.worker_id)
.await;
assert!(
!worker_exists,
"Expected worker to be removed from worker map"
