diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs
index 39dc1fa72..226ac747e 100644
--- a/nativelink-scheduler/src/simple_scheduler.rs
+++ b/nativelink-scheduler/src/simple_scheduler.rs
@@ -20,6 +20,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use std::time::{Instant, SystemTime};
 
+use async_lock::{Mutex, MutexGuard};
 use async_trait::async_trait;
 use futures::Future;
 use hashbrown::{HashMap, HashSet};
@@ -37,7 +38,6 @@ use nativelink_util::metrics_utils::{
 use nativelink_util::platform_properties::PlatformPropertyValue;
 use nativelink_util::spawn;
 use nativelink_util::task::JoinHandleDropGuard;
-use parking_lot::{Mutex, MutexGuard};
 use tokio::sync::{watch, Notify};
 use tokio::time::Duration;
 use tracing::{event, Level};
@@ -257,7 +257,7 @@ impl SimpleSchedulerImpl {
     /// If the task cannot be executed immediately it will be queued for execution
     /// based on priority and other metrics.
     /// All further updates to the action will be provided through `listener`.
-    fn add_action(
+    async fn add_action(
         &mut self,
         action_info: ActionInfo,
     ) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
@@ -341,7 +341,7 @@
             .map(|action| watch::channel(action.state.clone()).1)
     }
 
-    fn find_existing_action(
+    async fn find_existing_action(
         &self,
         unique_qualifier: &ActionInfoHashKey,
     ) -> Option<watch::Receiver<Arc<ActionState>>> {
@@ -444,7 +444,7 @@
     // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map
     // of capabilities of each worker and then try and match the actions to the worker using
     // the map lookup (ie. map reduce).
-    fn do_try_match(&mut self) {
+    async fn do_try_match(&mut self) {
         // TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in
         // the way. We need to conditionally remove items from the `queued_action`. Rust is working
         // to add `drain_filter`, which would in theory solve this problem, but because we need
@@ -598,7 +598,7 @@
         self.tasks_or_workers_change_notify.notify_one();
     }
 
-    fn update_action(
+    async fn update_action(
         &mut self,
         worker_id: &WorkerId,
         action_info_hash_key: &ActionInfoHashKey,
@@ -781,9 +781,9 @@
                             // really need to worry about this thread taking the lock
                             // starving other threads too much.
                             Some(inner_mux) => {
-                                let mut inner = inner_mux.lock();
+                                let mut inner = inner_mux.lock().await;
                                 let timer = metrics_for_do_try_match.do_try_match.begin_timer();
-                                inner.do_try_match();
+                                inner.do_try_match().await;
                                 timer.measure();
                             }
                             // If the inner went away it means the scheduler is shutting
@@ -801,14 +801,17 @@
 
     /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
     #[must_use]
-    pub fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
-        let inner = self.get_inner_lock();
+    pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
+        let inner = self.get_inner_lock().await;
         inner.workers.workers.contains(worker_id)
     }
 
     /// Checks to see if the worker can accept work. Should only be used in unit tests.
-    pub fn can_worker_accept_work_for_test(&self, worker_id: &WorkerId) -> Result<bool, Error> {
-        let mut inner = self.get_inner_lock();
+    pub async fn can_worker_accept_work_for_test(
+        &self,
+        worker_id: &WorkerId,
+    ) -> Result<bool, Error> {
+        let mut inner = self.get_inner_lock().await;
         let worker = inner.workers.workers.get_mut(worker_id).ok_or_else(|| {
             make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
         })?;
@@ -816,19 +819,22 @@
     }
 
     /// A unit test function used to send the keep alive message to the worker from the server.
-    pub fn send_keep_alive_to_worker_for_test(&self, worker_id: &WorkerId) -> Result<(), Error> {
-        let mut inner = self.get_inner_lock();
+    pub async fn send_keep_alive_to_worker_for_test(
+        &self,
+        worker_id: &WorkerId,
+    ) -> Result<(), Error> {
+        let mut inner = self.get_inner_lock().await;
         let worker = inner.workers.workers.get_mut(worker_id).ok_or_else(|| {
             make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
         })?;
         worker.keep_alive()
     }
 
-    fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> {
+    async fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> {
         // We don't use one of the wrappers because we only want to capture the time spent,
         // nothing else beacuse this is a hot path.
         let start = Instant::now();
-        let lock = self.inner.lock();
+        let lock: MutexGuard<SimpleSchedulerImpl> = self.inner.lock().await;
         self.metrics
             .lock_stall_time
             .fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed);
@@ -852,19 +858,21 @@ impl ActionScheduler for SimpleScheduler {
         &self,
         action_info: ActionInfo,
     ) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         self.metrics
             .add_action
-            .wrap(move || inner.add_action(action_info))
+            .wrap(inner.add_action(action_info))
+            .await
     }
 
     async fn find_existing_action(
         &self,
         unique_qualifier: &ActionInfoHashKey,
     ) -> Option<watch::Receiver<Arc<ActionState>>> {
-        let inner = self.get_inner_lock();
+        let inner = self.get_inner_lock().await;
         let result = inner
             .find_existing_action(unique_qualifier)
+            .await
             .or_else(|| inner.find_recently_completed_action(unique_qualifier));
         if result.is_some() {
             self.metrics.existing_actions_found.inc();
@@ -875,7 +883,9 @@
     }
 
     async fn clean_recently_completed_actions(&self) {
-        self.get_inner_lock().clean_recently_completed_actions();
+        self.get_inner_lock()
+            .await
+            .clean_recently_completed_actions();
         self.metrics.clean_recently_completed_actions.inc()
     }
 
@@ -892,7 +902,7 @@ impl WorkerScheduler for SimpleScheduler {
 
     async fn add_worker(&self, worker: Worker) -> Result<(), Error> {
         let worker_id = worker.id;
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         self.metrics.add_worker.wrap(move || {
             let res = inner
                 .workers
@@ -912,7 +922,7 @@
         action_info_hash_key: &ActionInfoHashKey,
         err: Error,
     ) {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         inner.update_action_with_internal_error(worker_id, action_info_hash_key, err);
     }
 
@@ -922,10 +932,11 @@
         action_info_hash_key: &ActionInfoHashKey,
         action_stage: ActionStage,
     ) -> Result<(), Error> {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         self.metrics
             .update_action
-            .wrap(move || inner.update_action(worker_id, action_info_hash_key, action_stage))
+            .wrap(inner.update_action(worker_id, action_info_hash_key, action_stage))
+            .await
     }
 
     async fn worker_keep_alive_received(
@@ -933,7 +944,7 @@
         worker_id: &WorkerId,
         timestamp: WorkerTimestamp,
     ) -> Result<(), Error> {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         inner
             .workers
             .refresh_lifetime(worker_id, timestamp)
@@ -941,7 +952,7 @@
     }
 
     async fn remove_worker(&self, worker_id: WorkerId) {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
        inner.immediate_evict_worker(
             &worker_id,
             make_err!(Code::Internal, "Received request to remove worker"),
@@ -949,7 +960,7 @@
     }
 
     async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         self.metrics.remove_timedout_workers.wrap(move || {
             // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire
             // map most of the time.
@@ -986,7 +997,7 @@
     }
 
     async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> {
-        let mut inner = self.get_inner_lock();
+        let mut inner = self.get_inner_lock().await;
         inner.set_drain_worker(worker_id, is_draining)
     }
 
@@ -1001,7 +1012,7 @@ impl MetricsComponent for SimpleScheduler {
         self.metrics.gather_metrics(c);
         {
             // We use the raw lock because we dont gather stats about gathering stats.
-            let inner = self.inner.lock();
+            let inner = self.inner.lock_blocking();
             c.publish(
                 "queued_actions_total",
                 &inner.queued_actions.len(),
@@ -1127,12 +1138,12 @@ impl MetricsComponent for AwaitedAction {
 
 #[derive(Default)]
 struct Metrics {
-    add_action: FuncCounterWrapper,
+    add_action: AsyncCounterWrapper,
     existing_actions_found: CounterWithTime,
     existing_actions_not_found: CounterWithTime,
     clean_recently_completed_actions: CounterWithTime,
     remove_timedout_workers: FuncCounterWrapper,
-    update_action: FuncCounterWrapper,
+    update_action: AsyncCounterWrapper,
     update_action_missing_action_result: CounterWithTime,
     update_action_from_wrong_worker: CounterWithTime,
     update_action_no_more_listeners: CounterWithTime,
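Editor's note: the core of the patch above is swapping `parking_lot::Mutex` for `async_lock::Mutex`, which is why `do_try_match`, `add_action`, `update_action`, and `get_inner_lock` all become `async` and every lock site gains `.await`. The sketch below (not part of the patch) illustrates the two locking paths the diff relies on: `lock().await` in async code, and `lock_blocking()` for the synchronous metrics-gathering path. `SchedulerState` and `do_work` are hypothetical stand-ins; it assumes the `async-lock` crate and `tokio` with the "rt" feature.

```rust
use async_lock::Mutex;

struct SchedulerState {
    queued: u64,
}

impl SchedulerState {
    // Mirrors `do_try_match` becoming `async`: the method can await while
    // the caller still holds the mutex guard, which an async mutex permits
    // without blocking the executor thread.
    async fn do_work(&mut self) {
        self.queued += 1;
        tokio::task::yield_now().await; // Awaiting with the lock held is legal here.
    }
}

fn main() {
    let inner = Mutex::new(SchedulerState { queued: 0 });

    // Async path: `lock().await` mirrors `get_inner_lock().await`.
    let rt = tokio::runtime::Builder::new_current_thread()
        .build()
        .unwrap();
    rt.block_on(async {
        let mut guard = inner.lock().await;
        guard.do_work().await;
    });

    // Sync path: `lock_blocking()` mirrors the `gather_metrics` hunk, which
    // runs outside an async context and therefore cannot `.await`.
    let guard = inner.lock_blocking();
    assert_eq!(guard.queued, 1);
}
```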
diff --git a/nativelink-service/tests/worker_api_server_test.rs b/nativelink-service/tests/worker_api_server_test.rs
index 22e78c990..b614cdea3 100644
--- a/nativelink-service/tests/worker_api_server_test.rs
+++ b/nativelink-service/tests/worker_api_server_test.rs
@@ -124,7 +124,8 @@ pub mod connect_worker_tests {
         let worker_exists = test_context
             .scheduler
-            .contains_worker_for_test(&test_context.worker_id);
+            .contains_worker_for_test(&test_context.worker_id)
+            .await;
         assert!(worker_exists, "Expected worker to exist in worker map");
 
         Ok(())
     }
@@ -151,7 +152,8 @@ pub mod keep_alive_tests {
                 .await?;
             let worker_exists = test_context
                 .scheduler
-                .contains_worker_for_test(&test_context.worker_id);
+                .contains_worker_for_test(&test_context.worker_id)
+                .await;
             assert!(worker_exists, "Expected worker to exist in worker map");
         }
         {
@@ -163,7 +165,8 @@
                 .await?;
             let worker_exists = test_context
                 .scheduler
-                .contains_worker_for_test(&test_context.worker_id);
+                .contains_worker_for_test(&test_context.worker_id)
+                .await;
             assert!(!worker_exists, "Expected worker to not exist in map");
         }
 
@@ -195,7 +198,8 @@
                 .await?;
             let worker_exists = test_context
                 .scheduler
-                .contains_worker_for_test(&test_context.worker_id);
+                .contains_worker_for_test(&test_context.worker_id)
+                .await;
             assert!(worker_exists, "Expected worker to exist in worker map");
         }
         {
@@ -217,7 +221,8 @@
                 .await?;
             let worker_exists = test_context
                 .scheduler
-                .contains_worker_for_test(&test_context.worker_id);
+                .contains_worker_for_test(&test_context.worker_id)
+                .await;
             assert!(worker_exists, "Expected worker to exist in map");
         }
 
@@ -234,6 +239,7 @@
         test_context
             .scheduler
             .send_keep_alive_to_worker_for_test(&test_context.worker_id)
+            .await
             .err_tip(|| "Could not send keep alive to worker")?;
 
         {
@@ -269,7 +275,8 @@ pub mod going_away_tests {
 
         let worker_exists = test_context
             .scheduler
-            .contains_worker_for_test(&test_context.worker_id);
+            .contains_worker_for_test(&test_context.worker_id)
+            .await;
         assert!(worker_exists, "Expected worker to exist in worker map");
 
         test_context
@@ -279,7 +286,8 @@
 
         let worker_exists = test_context
             .scheduler
-            .contains_worker_for_test(&test_context.worker_id);
+            .contains_worker_for_test(&test_context.worker_id)
+            .await;
         assert!(
             !worker_exists,
             "Expected worker to be removed from worker map"
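Editor's note: the scheduler diff also swaps two `Metrics` fields from `FuncCounterWrapper` to `AsyncCounterWrapper`, so the metrics wrapper now times a future (`.wrap(inner.add_action(action_info)).await`) rather than a closure. The sketch below is a simplified, hypothetical stand-in for that pattern, not nativelink's actual `AsyncCounterWrapper` from `nativelink_util::metrics_utils`, whose real API may differ; it assumes `tokio` with the "rt" feature.

```rust
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

#[derive(Default)]
struct AsyncCounter {
    calls: AtomicU64,
    total_nanos: AtomicU64,
}

impl AsyncCounter {
    // Count the call, then record how long the wrapped future took to resolve.
    // Analogous in spirit to `self.metrics.update_action.wrap(fut).await`.
    async fn wrap<F: Future>(&self, fut: F) -> F::Output {
        self.calls.fetch_add(1, Ordering::Relaxed);
        let start = Instant::now();
        let result = fut.await;
        self.total_nanos
            .fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed);
        result
    }
}

fn main() {
    let rt = tokio::runtime::Builder::new_current_thread()
        .build()
        .unwrap();
    let counter = AsyncCounter::default();
    let value = rt.block_on(counter.wrap(async { 40 + 2 }));
    assert_eq!(value, 42);
    assert_eq!(counter.calls.load(Ordering::Relaxed), 1);
}
```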