diff --git a/nativelink-config/src/cas_server.rs b/nativelink-config/src/cas_server.rs index 20cafc6ef..ad545ea9f 100644 --- a/nativelink-config/src/cas_server.rs +++ b/nativelink-config/src/cas_server.rs @@ -600,6 +600,15 @@ pub struct LocalWorkerConfig { /// of the environment variable being the value of the property of the /// action being executed of that name or the fixed value. pub additional_environment: Option>, + + /// The maximum number of actions a worker can executions . + /// After this limit is reached, the nativelink binary will exit. + /// + /// For Example: + /// If you would like for each individual action to spin up a kubernetes + /// pod and then exit on completion, you would set this value to 1. + /// Default: None (no limit) + pub actions_before_termination: Option, } #[allow(non_camel_case_types)] diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index 2812681a0..abe99a7c5 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -27,7 +27,7 @@ use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ - execute_result, ExecuteResult, KeepAliveRequest, UpdateForWorker, + execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker, }; use nativelink_store::fast_slow_store::FastSlowStore; use nativelink_util::action_messages::{ActionResult, ActionStage}; @@ -79,6 +79,10 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> { // always be zero if there are no actions running and no actions being waited // on by the scheduler. actions_in_transit: Arc, + // The number of actions a worker has completed. + actions_completed: Arc, + // The number of actions a worker has been assigned. + actions_assigned: Arc, metrics: Arc, } @@ -122,10 +126,18 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, running_actions_manager: Arc, metrics: Arc, ) -> Self { + if let Some(actions_before_termination) = config.actions_before_termination { + assert!( + actions_before_termination > 0, + "LocalWorkerImpl::new() - LocalWorkerConfig.actions_before_termination must be greater than 0" + ) + } Self { config, grpc_client, worker_id, + actions_completed: Arc::new(AtomicU64::new(0)), + actions_assigned: Arc::new(AtomicU64::new(0)), running_actions_manager, // Number of actions that have been received in `Update::StartAction`, but // not yet processed by running_actions_manager's spawn. This number should @@ -226,6 +238,9 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, let worker_id_clone = worker_id.clone(); let precondition_script_cfg = self.config.experimental_precondition_script.clone(); let actions_in_transit = self.actions_in_transit.clone(); + + let actions_completed = self.actions_completed.clone(); + let start_action_fut = self.metrics.clone().wrap(move |metrics| async move { metrics.preconditions.wrap(preconditions_met(precondition_script_cfg)) .and_then(|_| running_actions_manager.create_and_add_action(worker_id_clone, start_execute)) @@ -244,6 +259,8 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, .and_then(RunningAction::get_finished_result) // Note: We need ensure we run cleanup even if one of the other steps fail. .then(|result| async move { + + actions_completed.fetch_add(1, Ordering::Release); if let Err(e) = action.cleanup().await { return Result::::Err(e).merge(result); } @@ -267,7 +284,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, let action_stage = ActionStage::Completed(action_result); grpc_client.execution_response( ExecuteResult{ - worker_id, + worker_id: worker_id.clone(), instance_name, action_digest, salt, @@ -291,6 +308,19 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, }; self.actions_in_transit.fetch_add(1, Ordering::Release); + self.actions_assigned.fetch_add(1, Ordering::Release); + let execution_limit = &self.config.actions_before_termination; + if Some(self.actions_assigned.load(Ordering::Acquire)) == *execution_limit { + + let mut grpc_client = self.grpc_client.clone(); + grpc_client.going_away( + GoingAwayRequest { + worker_id: self.worker_id.clone(), + } + ) + .await + .err_tip(|| "Error while calling execution_response")?; + } futures.push( tokio::spawn(start_action_fut).map(move |res| { let res = res.err_tip(|| "Failed to launch spawn")?; @@ -300,6 +330,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, add_future_channel .send(make_publish_future(res).boxed()) .map_err(|_| make_err!(Code::Internal, "LocalWorker could not send future"))?; + Ok(()) }) .boxed() @@ -313,6 +344,18 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, }, res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??, }; + if let Some(limit) = self.config.actions_before_termination { + let completed = self.actions_completed.load(Ordering::Acquire); + let assigned = self.actions_assigned.load(Ordering::Acquire); + // Futures will never be empty due to keep alive + if futures.len() == 1 && assigned == limit && completed == assigned { + dbg!(format!( + "Worker with id {} reached max executions - terminating", + self.worker_id + )); + std::process::exit(0); + } + } } // Unreachable. } diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs index cf45ef2c3..07ca05cf1 100644 --- a/nativelink-worker/tests/local_worker_test.rs +++ b/nativelink-worker/tests/local_worker_test.rs @@ -29,7 +29,7 @@ mod utils { pub(crate) mod mock_running_actions_manager; } -use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty}; +use nativelink_config::cas_server::{EndpointConfig, LocalWorkerConfig, WorkerProperty}; use nativelink_error::{make_err, make_input_err, Code, Error}; use nativelink_proto::build::bazel::remote::execution::v2::platform::Property; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update; @@ -638,4 +638,152 @@ mod local_worker_tests { Ok(()) } + + #[tokio::test] + async fn worker_action_limit_test() -> Result<(), Box> { + const SALT_1: u64 = 1000; + + const ARBITRARY_LARGE_TIMEOUT: f32 = 10000.; + let local_worker_config = LocalWorkerConfig { + platform_properties: HashMap::new(), + worker_api_endpoint: EndpointConfig { + timeout: Some(ARBITRARY_LARGE_TIMEOUT), + ..Default::default() + }, + actions_before_termination: Some(1), + ..Default::default() + }; + let mut test_context = setup_local_worker_with_config(local_worker_config).await; + let streaming_response = test_context.maybe_streaming_response.take().unwrap(); + + { + // Ensure our worker connects and properties were sent. + let props = test_context + .client + .expect_connect_worker(Ok(streaming_response)) + .await; + assert_eq!(props, SupportedProperties::default()); + } + + let expected_worker_id = "foobar".to_string(); + + let mut tx_stream = test_context.maybe_tx_stream.take().unwrap(); + { + // First initialize our worker by sending the response to the connection request. + tx_stream + .send_data(encode_stream_proto(&UpdateForWorker { + update: Some(Update::ConnectionResult(ConnectionResult { + worker_id: expected_worker_id.clone(), + })), + })?) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + let action_digest_1 = DigestInfo::new([3u8; 32], 10); + let action_info_1 = ActionInfo { + command_digest: DigestInfo::new([1u8; 32], 10), + input_root_digest: DigestInfo::new([2u8; 32], 10), + timeout: Duration::from_secs(1), + platform_properties: PlatformProperties::default(), + priority: 0, + load_timestamp: SystemTime::UNIX_EPOCH, + insert_timestamp: SystemTime::UNIX_EPOCH, + unique_qualifier: ActionInfoHashKey { + instance_name: INSTANCE_NAME.to_string(), + digest: action_digest_1, + salt: SALT_1, + }, + skip_cache_lookup: true, + digest_function: DigestHasherFunc::Sha256, + }; + + { + // Send execution request. + tx_stream + .send_data(encode_stream_proto(&UpdateForWorker { + update: Some(Update::StartAction(StartExecute { + execute_request: Some(action_info_1.into()), + salt: SALT_1, + queued_timestamp: None, + })), + })?) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + // going_away should trigger once execution_limit actions have been assigned + test_context + .client + .expect_going_away(Ok(Response::new(()))) + .await; + + let action_result = ActionResult { + output_files: vec![], + output_folders: vec![], + output_file_symlinks: vec![], + output_directory_symlinks: vec![], + exit_code: 5, + stdout_digest: DigestInfo::new([21u8; 32], 10), + stderr_digest: DigestInfo::new([22u8; 32], 10), + execution_metadata: ExecutionMetadata { + worker: expected_worker_id.clone(), + queued_timestamp: SystemTime::UNIX_EPOCH, + worker_start_timestamp: SystemTime::UNIX_EPOCH, + worker_completed_timestamp: SystemTime::UNIX_EPOCH, + input_fetch_start_timestamp: SystemTime::UNIX_EPOCH, + input_fetch_completed_timestamp: SystemTime::UNIX_EPOCH, + execution_start_timestamp: SystemTime::UNIX_EPOCH, + execution_completed_timestamp: SystemTime::UNIX_EPOCH, + output_upload_start_timestamp: SystemTime::UNIX_EPOCH, + output_upload_completed_timestamp: SystemTime::UNIX_EPOCH, + }, + server_logs: HashMap::new(), + error: None, + message: String::new(), + }; + let running_action = Arc::new(MockRunningAction::new()); + + // Send and wait for response from create_and_add_action to RunningActionsManager. + test_context + .actions_manager + .expect_create_and_add_action(Ok(running_action.clone())) + .await; + + // Now the RunningAction needs to send a series of state updates. This shortcuts them + // into a single call (shortcut for prepare, execute, upload, collect_results, cleanup). + running_action + .simple_expect_get_finished_result(Ok(action_result.clone())) + .await?; + + // Expect the action to be updated in the action cache. + let (stored_digest, stored_result, digest_hasher) = test_context + .actions_manager + .expect_cache_action_result() + .await; + assert_eq!(stored_digest, action_digest_1); + assert_eq!(stored_result, action_result.clone()); + assert_eq!(digest_hasher, DigestHasherFunc::Sha256); + + // Now our client should be notified that our runner finished. + let execution_response = test_context + .client + .expect_execution_response(Ok(Response::new(()))) + .await; + + // Now ensure the final results match our expectations. + assert_eq!( + execution_response, + ExecuteResult { + worker_id: expected_worker_id.clone(), + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest_1.into()), + salt: SALT_1, + result: Some(execute_result::Result::ExecuteResponse( + ActionStage::Completed(action_result).into() + )), + } + ); + Ok(()) + } } diff --git a/nativelink-worker/tests/utils/local_worker_test_utils.rs b/nativelink-worker/tests/utils/local_worker_test_utils.rs index 07d1fda29..7dca2217a 100644 --- a/nativelink-worker/tests/utils/local_worker_test_utils.rs +++ b/nativelink-worker/tests/utils/local_worker_test_utils.rs @@ -43,12 +43,14 @@ use super::mock_running_actions_manager::MockRunningActionsManager; enum WorkerClientApiCalls { ConnectWorker(SupportedProperties), ExecutionResponse(ExecuteResult), + GoingAwayRequest(GoingAwayRequest), } #[derive(Debug)] enum WorkerClientApiReturns { ConnectWorker(Result>, Status>), ExecutionResponse(Result, Status>), + GoingAwayRequest(Result, Status>), } #[derive(Clone)] @@ -79,6 +81,29 @@ impl Default for MockWorkerApiClient { } impl MockWorkerApiClient { + pub async fn expect_going_away( + &mut self, + result: Result, Status>, + ) -> GoingAwayRequest { + let mut rx_call_lock = self.rx_call.lock().await; + let req = match rx_call_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + { + WorkerClientApiCalls::GoingAwayRequest(req) => req, + req @ WorkerClientApiCalls::ConnectWorker(_) => { + panic!("expect_going_away expected GoingAwayRequest, got : {req:?}") + } + req @ WorkerClientApiCalls::ExecutionResponse(_) => { + panic!("expect_going_away expected GoingAwayRequest, got : {req:?}") + } + }; + self.tx_resp + .send(WorkerClientApiReturns::GoingAwayRequest(result)) + .expect("Could not send request to mpsc"); + req + } pub async fn expect_connect_worker( &mut self, result: Result>, Status>, @@ -93,6 +118,9 @@ impl MockWorkerApiClient { req @ WorkerClientApiCalls::ExecutionResponse(_) => { panic!("expect_connect_worker expected ConnectWorker, got : {req:?}") } + req @ WorkerClientApiCalls::GoingAwayRequest(_) => { + panic!("expect_connect_worker expected ConnectWorker, got : {req:?}") + } }; self.tx_resp .send(WorkerClientApiReturns::ConnectWorker(result)) @@ -114,6 +142,9 @@ impl MockWorkerApiClient { req @ WorkerClientApiCalls::ConnectWorker(_) => { panic!("expect_execution_response expected ExecutionResponse, got : {req:?}") } + req @ WorkerClientApiCalls::GoingAwayRequest(_) => { + panic!("expect_execution_response expected ExecutionResponse, got : {req:?}") + } }; self.tx_resp .send(WorkerClientApiReturns::ExecutionResponse(result)) @@ -141,6 +172,9 @@ impl WorkerApiClientTrait for MockWorkerApiClient { resp @ WorkerClientApiReturns::ExecutionResponse(_) => { panic!("connect_worker expected ConnectWorker response, received {resp:?}") } + resp @ WorkerClientApiReturns::GoingAwayRequest(_) => { + panic!("connect_worker expected ConnectWorker response, received {resp:?}") + } } } @@ -148,8 +182,24 @@ impl WorkerApiClientTrait for MockWorkerApiClient { unreachable!(); } - async fn going_away(&mut self, _request: GoingAwayRequest) -> Result, Status> { - unreachable!(); + async fn going_away(&mut self, request: GoingAwayRequest) -> Result, Status> { + self.tx_call + .send(WorkerClientApiCalls::GoingAwayRequest(request)) + .expect("Could not send request to mpsc"); + let mut rx_resp_lock = self.rx_resp.lock().await; + match rx_resp_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + { + WorkerClientApiReturns::GoingAwayRequest(result) => result, + resp @ WorkerClientApiReturns::ConnectWorker(_) => { + panic!("going_away_response expected GoingAwayRequest response, received {resp:?}") + } + resp @ WorkerClientApiReturns::ExecutionResponse(_) => { + panic!("going_away_response expected GoingAwayRequest response, received {resp:?}") + } + } } async fn execution_response(&mut self, request: ExecuteResult) -> Result, Status> { @@ -166,6 +216,9 @@ impl WorkerApiClientTrait for MockWorkerApiClient { resp @ WorkerClientApiReturns::ConnectWorker(_) => { panic!("execution_response expected ExecutionResponse response, received {resp:?}") } + resp @ WorkerClientApiReturns::GoingAwayRequest(_) => { + panic!("execution_response expected ExecutionResponse response, received {resp:?}") + } } } }