From 2b8f1ee4f1078afb47f1d012ad8a347e752817db Mon Sep 17 00:00:00 2001 From: "Nathan (Blaise) Bruer" Date: Wed, 7 Aug 2024 00:34:22 -0500 Subject: [PATCH] Migrate much of the ActionScheduler API to ClientStateManager API (#1241) Mostly a cosmetic change to move the compatible parts of ActionSchedulers to use ClientStateManager API instead and implement all related requirements to existing schedulers. towards #1213 --- nativelink-scheduler/BUILD.bazel | 1 + nativelink-scheduler/src/action_scheduler.rs | 20 ++---- .../src/cache_lookup_scheduler.rs | 60 ++++++++++------ nativelink-scheduler/src/grpc_scheduler.rs | 68 +++++++++++++++---- .../src/property_modifier_scheduler.rs | 58 ++++++++++++---- nativelink-scheduler/src/simple_scheduler.rs | 64 +++++++---------- .../tests/cache_lookup_scheduler_test.rs | 28 +++++--- .../tests/property_modifier_scheduler_test.rs | 44 ++++++++---- .../tests/simple_scheduler_test.rs | 46 +++++++++---- .../tests/utils/mock_scheduler.rs | 39 ++++++----- .../tests/utils/scheduler_utils.rs | 6 +- nativelink-service/src/execution_server.rs | 13 ++-- .../src/operation_state_manager.rs | 6 +- 13 files changed, 287 insertions(+), 166 deletions(-) diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index 78964fb08..218013cc0 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -82,6 +82,7 @@ rust_test_suite( "@crates//:pretty_assertions", "@crates//:prost", "@crates//:tokio", + "@crates//:tokio-stream", "@crates//:uuid", ], ) diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index 2712110b7..bf5669e78 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -17,31 +17,19 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::Error; use nativelink_metric::RootMetricsComponent; -use nativelink_util::action_messages::{ActionInfo, OperationId}; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::ClientStateManager; use crate::platform_property_manager::PlatformPropertyManager; /// ActionScheduler interface is responsible for interactions between the scheduler /// and action related operations. #[async_trait] -pub trait ActionScheduler: Sync + Send + Unpin + RootMetricsComponent + 'static { +pub trait ActionScheduler: + ClientStateManager + Sync + Send + Unpin + RootMetricsComponent + 'static +{ /// Returns the platform property manager. async fn get_platform_property_manager( &self, instance_name: &str, ) -> Result, Error>; - - /// Adds an action to the scheduler for remote execution. - async fn add_action( - &self, - client_operation_id: OperationId, - action_info: ActionInfo, - ) -> Result, Error>; - - /// Find an existing action by its name. - async fn find_by_client_operation_id( - &self, - client_operation_id: &OperationId, - ) -> Result>, Error>; } diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index 7e60b51a6..a462131b6 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -29,7 +29,9 @@ use nativelink_util::action_messages::{ use nativelink_util::background_spawn; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, +}; use nativelink_util::store_trait::Store; use parking_lot::{Mutex, MutexGuard}; use scopeguard::guard; @@ -147,23 +149,11 @@ impl CacheLookupScheduler { inflight_cache_checks: Default::default(), }) } -} -#[async_trait] -impl ActionScheduler for CacheLookupScheduler { - async fn get_platform_property_manager( - &self, - instance_name: &str, - ) -> Result, Error> { - self.action_scheduler - .get_platform_property_manager(instance_name) - .await - } - - async fn add_action( + async fn inner_add_action( &self, client_operation_id: OperationId, - action_info: ActionInfo, + action_info: Arc, ) -> Result, Error> { let unique_key = match &action_info.unique_qualifier { ActionUniqueQualifier::Cachable(unique_key) => unique_key.clone(), @@ -320,14 +310,46 @@ impl ActionScheduler for CacheLookupScheduler { .err_tip(|| "In CacheLookupScheduler::add_action") } - async fn find_by_client_operation_id( + async fn inner_filter_operations( + &self, + filter: OperationFilter, + ) -> Result { + self.action_scheduler + .filter_operations(filter) + .await + .err_tip(|| "In CacheLookupScheduler::filter_operations") + } +} + +#[async_trait] +impl ActionScheduler for CacheLookupScheduler { + async fn get_platform_property_manager( &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { + instance_name: &str, + ) -> Result, Error> { self.action_scheduler - .find_by_client_operation_id(client_operation_id) + .get_platform_property_manager(instance_name) .await } } +#[async_trait] +impl ClientStateManager for CacheLookupScheduler { + async fn add_action( + &self, + client_operation_id: OperationId, + action_info: Arc, + ) -> Result, Error> { + self.inner_add_action(client_operation_id, action_info) + .await + } + + async fn filter_operations( + &self, + filter: OperationFilter, + ) -> Result { + self.inner_filter_operations(filter).await + } +} + impl RootMetricsComponent for CacheLookupScheduler {} diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 3e8eae3dd..a8793d14f 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -19,8 +19,8 @@ use std::time::Duration; use async_trait::async_trait; use futures::stream::unfold; -use futures::TryFutureExt; -use nativelink_error::{make_err, Code, Error, ResultExt}; +use futures::{StreamExt, TryFutureExt}; +use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; use nativelink_metric::{MetricsComponent, RootMetricsComponent}; use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient; use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient; @@ -32,7 +32,9 @@ use nativelink_util::action_messages::{ ActionInfo, ActionState, ActionUniqueQualifier, OperationId, DEFAULT_EXECUTION_PRIORITY, }; use nativelink_util::connection_manager::ConnectionManager; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, +}; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::{background_spawn, tls_utils}; use parking_lot::Mutex; @@ -217,11 +219,8 @@ impl GrpcScheduler { "Upstream scheduler didn't accept action." )) } -} -#[async_trait] -impl ActionScheduler for GrpcScheduler { - async fn get_platform_property_manager( + async fn inner_get_platform_property_manager( &self, instance_name: &str, ) -> Result, Error> { @@ -268,10 +267,10 @@ impl ActionScheduler for GrpcScheduler { .await } - async fn add_action( + async fn inner_add_action( &self, _client_operation_id: OperationId, - action_info: ActionInfo, + action_info: Arc, ) -> Result, Error> { let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY { None @@ -314,10 +313,17 @@ impl ActionScheduler for GrpcScheduler { Self::stream_state(result_stream).await } - async fn find_by_client_operation_id( + async fn inner_filter_operations( &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { + filter: OperationFilter, + ) -> Result { + error_if!(filter != OperationFilter { + client_operation_id: filter.client_operation_id.clone(), + ..Default::default() + }, "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}"); + let client_operation_id = filter.client_operation_id.ok_or_else(|| { + make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations") + })?; let request = WaitExecutionRequest { name: client_operation_id.to_string(), }; @@ -336,17 +342,51 @@ impl ActionScheduler for GrpcScheduler { .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; match result_stream { - Ok(result_stream) => Ok(Some(result_stream)), + Ok(result_stream) => Ok(unfold( + Some(result_stream), + |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) }, + ) + .boxed()), Err(err) => { event!( Level::WARN, ?err, "Error looking up action with upstream scheduler" ); - Ok(None) + Ok(futures::stream::empty().boxed()) } } } } +#[async_trait] +impl ClientStateManager for GrpcScheduler { + async fn add_action( + &self, + client_operation_id: OperationId, + action_info: Arc, + ) -> Result, Error> { + self.inner_add_action(client_operation_id, action_info) + .await + } + + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter).await + } +} + +#[async_trait] +impl ActionScheduler for GrpcScheduler { + async fn get_platform_property_manager( + &self, + instance_name: &str, + ) -> Result, Error> { + self.inner_get_platform_property_manager(instance_name) + .await + } +} + impl RootMetricsComponent for GrpcScheduler {} diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs index 09b75c4ed..83338fe06 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -21,7 +21,9 @@ use nativelink_config::schedulers::{PropertyModification, PropertyType}; use nativelink_error::{Error, ResultExt}; use nativelink_metric::{MetricsComponent, RootMetricsComponent}; use nativelink_util::action_messages::{ActionInfo, OperationId}; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, +}; use parking_lot::Mutex; use crate::action_scheduler::ActionScheduler; @@ -47,11 +49,8 @@ impl PropertyModifierScheduler { property_managers: Mutex::new(HashMap::new()), } } -} -#[async_trait] -impl ActionScheduler for PropertyModifierScheduler { - async fn get_platform_property_manager( + async fn inner_get_platform_property_manager( &self, instance_name: &str, ) -> Result, Error> { @@ -91,19 +90,20 @@ impl ActionScheduler for PropertyModifierScheduler { Ok(property_manager) } - async fn add_action( + async fn inner_add_action( &self, client_operation_id: OperationId, - mut action_info: ActionInfo, + mut action_info: Arc, ) -> Result, Error> { let platform_property_manager = self - .get_platform_property_manager(action_info.unique_qualifier.instance_name()) + .inner_get_platform_property_manager(action_info.unique_qualifier.instance_name()) .await .err_tip(|| "In PropertyModifierScheduler::add_action")?; + let action_info_mut = Arc::make_mut(&mut action_info); for modification in &self.modifications { match modification { PropertyModification::add(addition) => { - action_info.platform_properties.properties.insert( + action_info_mut.platform_properties.properties.insert( addition.name.clone(), platform_property_manager .make_prop_value(&addition.name, &addition.value) @@ -111,7 +111,7 @@ impl ActionScheduler for PropertyModifierScheduler { ) } PropertyModification::remove(name) => { - action_info.platform_properties.properties.remove(name) + action_info_mut.platform_properties.properties.remove(name) } }; } @@ -120,14 +120,42 @@ impl ActionScheduler for PropertyModifierScheduler { .await } - async fn find_by_client_operation_id( + async fn inner_filter_operations( &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { - self.scheduler - .find_by_client_operation_id(client_operation_id) + filter: OperationFilter, + ) -> Result { + self.scheduler.filter_operations(filter).await + } +} + +#[async_trait] +impl ActionScheduler for PropertyModifierScheduler { + async fn get_platform_property_manager( + &self, + instance_name: &str, + ) -> Result, Error> { + self.inner_get_platform_property_manager(instance_name) .await } } +#[async_trait] +impl ClientStateManager for PropertyModifierScheduler { + async fn add_action( + &self, + client_operation_id: OperationId, + action_info: Arc, + ) -> Result, Error> { + self.inner_add_action(client_operation_id, action_info) + .await + } + + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter).await + } +} + impl RootMetricsComponent for PropertyModifierScheduler {} diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index f20cd5deb..5fc34a1cd 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -139,7 +139,7 @@ impl SimpleScheduler { /// for execution based on priority and other metrics. /// All further updates to the action will be provided through the returned /// value. - async fn add_action( + async fn inner_add_action( &self, client_operation_id: OperationId, action_info: Arc, @@ -155,25 +155,14 @@ impl SimpleScheduler { ))) } - async fn find_by_client_operation_id( + async fn inner_filter_operations( &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { - let filter = OperationFilter { - client_operation_id: Some(client_operation_id.clone()), - ..Default::default() - }; - let filter_result = self.client_state_manager.filter_operations(filter).await; - - let mut stream = filter_result - .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting filter result")?; - let Some(action_state_result) = stream.next().await else { - return Ok(None); - }; - Ok(Some(Box::new(SimpleSchedulerActionStateResult::new( - client_operation_id.clone(), - action_state_result, - )))) + filter: OperationFilter, + ) -> Result { + self.client_state_manager + .filter_operations(filter) + .await + .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting filter result") } async fn get_queued_operations(&self) -> Result { @@ -370,34 +359,31 @@ impl SimpleScheduler { } #[async_trait] -impl ActionScheduler for SimpleScheduler { - async fn get_platform_property_manager( - &self, - _instance_name: &str, - ) -> Result, Error> { - Ok(self.platform_property_manager.clone()) - } - +impl ClientStateManager for SimpleScheduler { async fn add_action( &self, client_operation_id: OperationId, - action_info: ActionInfo, + action_info: Arc, ) -> Result, Error> { - self.add_action(client_operation_id, Arc::new(action_info)) + self.inner_add_action(client_operation_id, action_info) .await } - async fn find_by_client_operation_id( + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter).await + } +} + +#[async_trait] +impl ActionScheduler for SimpleScheduler { + async fn get_platform_property_manager( &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { - let maybe_receiver = self - .find_by_client_operation_id(client_operation_id) - .await - .err_tip(|| { - format!("Error while finding action with client id: {client_operation_id:?}") - })?; - Ok(maybe_receiver) + _instance_name: &str, + ) -> Result, Error> { + Ok(self.platform_property_manager.clone()) } } diff --git a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs index 097ea4d72..46953de36 100644 --- a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs +++ b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs @@ -33,11 +33,13 @@ use nativelink_util::action_messages::{ ActionResult, ActionStage, ActionState, ActionUniqueQualifier, OperationId, }; use nativelink_util::common::DigestInfo; +use nativelink_util::operation_state_manager::{ClientStateManager, OperationFilter}; use nativelink_util::store_trait::{Store, StoreLike}; use pretty_assertions::assert_eq; use prost::Message; use tokio::sync::watch; use tokio::{self}; +use tokio_stream::StreamExt; use utils::mock_scheduler::MockActionScheduler; use utils::scheduler_utils::{make_base_action_info, TokioWatchActionStateResult, INSTANCE_NAME}; @@ -99,8 +101,9 @@ async fn add_action_handles_skip_cache() -> Result<(), Error> { let ActionUniqueQualifier::Cachable(action_key) = action_info.unique_qualifier.clone() else { panic!("This test should be testing when item was cached first"); }; - let mut skip_cache_action = action_info.clone(); + let mut skip_cache_action = action_info.as_ref().clone(); skip_cache_action.unique_qualifier = ActionUniqueQualifier::Uncachable(action_key); + let skip_cache_action = Arc::new(skip_cache_action); let client_operation_id = OperationId::default(); let _ = join!( context @@ -110,7 +113,7 @@ async fn add_action_handles_skip_cache() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id, - Arc::new(action_info), + action_info, forward_watch_channel_rx )))) ); @@ -121,15 +124,22 @@ async fn add_action_handles_skip_cache() -> Result<(), Error> { async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_cache_scheduler()?; let client_operation_id = OperationId::default(); - let (actual_result, actual_client_id) = join!( - context - .cache_scheduler - .find_by_client_operation_id(&client_operation_id), + let (actual_result, actual_filter) = join!( + context.cache_scheduler.filter_operations(OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }), context .mock_scheduler - .expect_find_by_client_operation_id(Ok(None)), + .expect_filter_operations(Ok(Box::pin(futures::stream::empty()))), + ); + assert_eq!(true, actual_result.unwrap().next().await.is_none()); + assert_eq!( + OperationFilter { + client_operation_id: Some(client_operation_id), + ..Default::default() + }, + actual_filter ); - assert_eq!(true, actual_result.unwrap().is_none()); - assert_eq!(client_operation_id, actual_client_id); Ok(()) } diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index f3a0c97ed..6d984d2df 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -21,7 +21,7 @@ mod utils { pub(crate) mod scheduler_utils; } -use futures::join; +use futures::{join, StreamExt}; use nativelink_config::schedulers::{PlatformPropertyAddition, PropertyModification, PropertyType}; use nativelink_error::Error; use nativelink_macro::nativelink_test; @@ -30,6 +30,7 @@ use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::property_modifier_scheduler::PropertyModifierScheduler; use nativelink_util::action_messages::{ActionStage, ActionState, OperationId}; use nativelink_util::common::DigestInfo; +use nativelink_util::operation_state_manager::{ClientStateManager, OperationFilter}; use nativelink_util::platform_properties::PlatformPropertyValue; use pretty_assertions::assert_eq; use tokio::sync::watch; @@ -88,7 +89,7 @@ async fn add_action_adds_property() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id.clone(), - Arc::new(action_info), + action_info, forward_watch_channel_rx )))), ); @@ -110,11 +111,14 @@ async fn add_action_overwrites_property() -> Result<(), Error> { name: name.clone(), value: replaced_value.clone(), })]); - let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); + let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()) + .as_ref() + .clone(); action_info .platform_properties .properties .insert(name.clone(), PlatformPropertyValue::Unknown(original_value)); + let action_info = Arc::new(action_info); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { operation_id: OperationId::default(), @@ -137,7 +141,7 @@ async fn add_action_overwrites_property() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id.clone(), - Arc::new(action_info), + action_info, forward_watch_channel_rx )))), ); @@ -183,7 +187,7 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id.clone(), - Arc::new(action_info), + action_info, forward_watch_channel_rx )))), ); @@ -229,7 +233,7 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id.clone(), - Arc::new(action_info), + action_info, forward_watch_channel_rx )))), ); @@ -246,11 +250,14 @@ async fn add_action_property_remove() -> Result<(), Error> { let name = "name".to_string(); let value = "value".to_string(); let context = make_modifier_scheduler(vec![PropertyModification::remove(name.clone())]); - let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); + let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()) + .as_ref() + .clone(); action_info .platform_properties .properties .insert(name, PlatformPropertyValue::Unknown(value)); + let action_info = Arc::new(action_info); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { operation_id: OperationId::default(), @@ -270,7 +277,7 @@ async fn add_action_property_remove() -> Result<(), Error> { .mock_scheduler .expect_add_action(Ok(Box::new(TokioWatchActionStateResult::new( client_operation_id.clone(), - Arc::new(action_info), + action_info, forward_watch_channel_rx )))), ); @@ -285,17 +292,26 @@ async fn add_action_property_remove() -> Result<(), Error> { #[nativelink_test] async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_modifier_scheduler(vec![]); - let operation_id = OperationId::default(); - let (actual_result, actual_operation_id) = join!( + let client_operation_id = OperationId::default(); + let (actual_result, actual_filter) = join!( context .modifier_scheduler - .find_by_client_operation_id(&operation_id), + .filter_operations(OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }), context .mock_scheduler - .expect_find_by_client_operation_id(Ok(None)), + .expect_filter_operations(Ok(Box::pin(futures::stream::empty()))), + ); + assert_eq!(true, actual_result.unwrap().next().await.is_none()); + assert_eq!( + OperationFilter { + client_operation_id: Some(client_operation_id), + ..Default::default() + }, + actual_filter ); - assert_eq!(true, actual_result.unwrap().is_none()); - assert_eq!(operation_id, actual_operation_id); Ok(()) } diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index cd5bbb480..f0d1e70b1 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -17,8 +17,8 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use futures::poll; use futures::task::Poll; +use futures::{poll, StreamExt}; use mock_instant::MockClock; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; @@ -26,7 +26,6 @@ use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, Exe use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_scheduler::action_scheduler::ActionScheduler; use nativelink_scheduler::simple_scheduler::SimpleScheduler; use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; @@ -36,7 +35,9 @@ use nativelink_util::action_messages::{ }; use nativelink_util::common::DigestInfo; use nativelink_util::instant_wrapper::MockInstantWrapped; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ClientStateManager, OperationFilter, +}; use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; use pretty_assertions::assert_eq; use tokio::sync::mpsc; @@ -133,7 +134,7 @@ async fn setup_action( insert_timestamp: SystemTime, ) -> Result, Error> { let mut action_info = make_base_action_info(insert_timestamp, action_digest); - action_info.platform_properties = platform_properties; + Arc::make_mut(&mut action_info).platform_properties = platform_properties; let client_id = OperationId::default(); let result = scheduler.add_action(client_id, action_info).await; tokio::task::yield_now().await; // Allow task<->worker matcher to run. @@ -230,9 +231,14 @@ async fn find_executing_action() -> Result<(), Error> { // Drop our receiver and look up a new one. drop(action_listener); let mut action_listener = scheduler - .find_by_client_operation_id(&client_operation_id) + .filter_operations(OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }) .await .unwrap() + .next() + .await .expect("Action not found"); { @@ -1062,9 +1068,14 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E // Now look up a channel after the action has completed. let mut action_listener = scheduler - .find_by_client_operation_id(&client_id) + .filter_operations(OperationFilter { + client_operation_id: Some(client_id.clone()), + ..Default::default() + }) .await .unwrap() + .next() + .await .expect("Action not found"); { // Client should get notification saying it has been completed. @@ -1784,9 +1795,14 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { drop(action_listener); let mut new_action_listener = scheduler - .find_by_client_operation_id(&client_id) + .filter_operations(OperationFilter { + client_operation_id: Some(client_id.clone()), + ..Default::default() + }) .await .unwrap() + .next() + .await .expect("Action not found"); // We should get one notification saying it's queued. @@ -1807,12 +1823,18 @@ async fn client_reconnect_keeps_action_alive() -> Result<(), Error> { // Eviction happens when someone touches the internal // evicting map. So we constantly ask for some other client // to trigger eviction logic. - scheduler - .find_by_client_operation_id(&OperationId::from_raw_string( - "dummy_client_id".to_string(), - )) + assert!(scheduler + .filter_operations(OperationFilter { + client_operation_id: Some(OperationId::from_raw_string( + "dummy_client_id".to_string(), + )), + ..Default::default() + }) + .await + .unwrap() + .next() .await - .unwrap(); + .is_none()); } Ok(()) diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index d1c9469cb..b3cc7043a 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -21,7 +21,9 @@ use nativelink_scheduler::action_scheduler::ActionScheduler; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_util::{ action_messages::{ActionInfo, OperationId}, - operation_state_manager::ActionStateResult, + operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, + }, }; use tokio::sync::{mpsc, Mutex}; @@ -29,13 +31,13 @@ use tokio::sync::{mpsc, Mutex}; enum ActionSchedulerCalls { GetPlatformPropertyManager(String), AddAction((OperationId, ActionInfo)), - FindExistingAction(OperationId), + FilterOperations(OperationFilter), } enum ActionSchedulerReturns { GetPlatformPropertyManager(Result, Error>), AddAction(Result, Error>), - FindExistingAction(Result>, Error>), + FilterOperations(Result, Error>), } #[derive(MetricsComponent)] @@ -103,12 +105,12 @@ impl MockActionScheduler { req } - pub async fn expect_find_by_client_operation_id( + pub async fn expect_filter_operations( &self, - result: Result>, Error>, - ) -> OperationId { + result: Result, Error>, + ) -> OperationFilter { let mut rx_call_lock = self.rx_call.lock().await; - let ActionSchedulerCalls::FindExistingAction(req) = rx_call_lock + let ActionSchedulerCalls::FilterOperations(req) = rx_call_lock .recv() .await .expect("Could not receive msg in mpsc") @@ -116,7 +118,7 @@ impl MockActionScheduler { panic!("Got incorrect call waiting for find_by_client_operation_id") }; self.tx_resp - .send(ActionSchedulerReturns::FindExistingAction(result)) + .send(ActionSchedulerReturns::FilterOperations(result)) .map_err(|_| make_input_err!("Could not send request to mpsc")) .unwrap(); req @@ -144,16 +146,19 @@ impl ActionScheduler for MockActionScheduler { _ => panic!("Expected get_platform_property_manager return value"), } } +} +#[async_trait] +impl ClientStateManager for MockActionScheduler { async fn add_action( &self, client_operation_id: OperationId, - action_info: ActionInfo, + action_info: Arc, ) -> Result, Error> { self.tx_call .send(ActionSchedulerCalls::AddAction(( client_operation_id, - action_info, + action_info.as_ref().clone(), ))) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; @@ -167,14 +172,12 @@ impl ActionScheduler for MockActionScheduler { } } - async fn find_by_client_operation_id( - &self, - client_operation_id: &OperationId, - ) -> Result>, Error> { + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { self.tx_call - .send(ActionSchedulerCalls::FindExistingAction( - client_operation_id.clone(), - )) + .send(ActionSchedulerCalls::FilterOperations(filter)) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; match rx_resp_lock @@ -182,7 +185,7 @@ impl ActionScheduler for MockActionScheduler { .await .expect("Could not receive msg in mpsc") { - ActionSchedulerReturns::FindExistingAction(result) => result, + ActionSchedulerReturns::FilterOperations(result) => result, _ => panic!("Expected find_by_client_operation_id return value"), } } diff --git a/nativelink-scheduler/tests/utils/scheduler_utils.rs b/nativelink-scheduler/tests/utils/scheduler_utils.rs index 992f540ac..66552dd4f 100644 --- a/nativelink-scheduler/tests/utils/scheduler_utils.rs +++ b/nativelink-scheduler/tests/utils/scheduler_utils.rs @@ -32,8 +32,8 @@ pub const INSTANCE_NAME: &str = "foobar_instance_name"; pub fn make_base_action_info( insert_timestamp: SystemTime, action_digest: DigestInfo, -) -> ActionInfo { - ActionInfo { +) -> Arc { + Arc::new(ActionInfo { command_digest: DigestInfo::new([0u8; 32], 0), input_root_digest: DigestInfo::new([0u8; 32], 0), timeout: Duration::MAX, @@ -48,7 +48,7 @@ pub fn make_base_action_info( digest_function: DigestHasherFunc::Sha256, digest: action_digest, }), - } + }) } pub struct TokioWatchActionStateResult { diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index 3849d56c2..d19d38c65 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures::stream::unfold; -use futures::Stream; +use futures::{Stream, StreamExt}; use nativelink_config::cas_server::{ExecutionConfig, InstanceName}; use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::execution_server::{ @@ -36,7 +36,7 @@ use nativelink_util::action_messages::{ }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::{make_ctx_for_hash_func, DigestHasherFunc}; -use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::operation_state_manager::{ActionStateResult, OperationFilter}; use nativelink_util::platform_properties::PlatformProperties; use nativelink_util::store_trait::Store; use tonic::{Request, Response, Status}; @@ -283,7 +283,7 @@ impl ExecutionServer { let action_listener = instance_info .scheduler - .add_action(OperationId::default(), action_info) + .add_action(OperationId::default(), Arc::new(action_info)) .await .err_tip(|| "Failed to schedule task")?; @@ -316,9 +316,14 @@ impl ExecutionServer { }; let Some(rx) = instance_info .scheduler - .find_by_client_operation_id(&client_operation_id) + .filter_operations(OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }) .await .err_tip(|| "Error running find_existing_action in ExecutionServer::wait_execution")? + .next() + .await else { return Err(Status::not_found("Failed to find existing task")); }; diff --git a/nativelink-util/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs index 237b4bc79..7cede96f2 100644 --- a/nativelink-util/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -106,10 +106,10 @@ pub trait ClientStateManager: Sync + Send + MetricsComponent { ) -> Result, Error>; /// Returns a stream of operations that match the filter. - async fn filter_operations<'a>( - &'a self, + async fn filter_operations( + &self, filter: OperationFilter, - ) -> Result, Error>; + ) -> Result; } #[async_trait]