Skip to content

Commit

Permalink
Kill worker after action execution limit is reached
Browse files Browse the repository at this point in the history
Allows users to set the maximum number of action executions a worker is
allowed to complete. Upon reaching this limit, the worker will no longer
accept new jobs, and will exit upon completing all assigned ones.

closes: #815
  • Loading branch information
Zach Birenbaum committed Apr 2, 2024
1 parent c60fb55 commit d295c88
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 5 deletions.
9 changes: 9 additions & 0 deletions nativelink-config/src/cas_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,15 @@ pub struct LocalWorkerConfig {
/// of the environment variable being the value of the property of the
/// action being executed of that name or the fixed value.
pub additional_environment: Option<HashMap<String, EnvironmentSource>>,

/// The maximum number of action executions a worker can complete.
/// Once this limit is reached, the worker no longer accepts new jobs, and
/// the nativelink binary will exit after all assigned actions complete.
/// A value of None means no limit.
///
/// For example:
/// If you would like each individual action to spin up a Kubernetes
/// pod and then exit on completion, you would set this value to 1.
pub execution_limit: Option<u64>,
}

#[allow(non_camel_case_types)]
Expand Down
37 changes: 35 additions & 2 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
execute_result, ExecuteResult, KeepAliveRequest, UpdateForWorker,
execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker,
};
use nativelink_store::fast_slow_store::FastSlowStore;
use nativelink_util::action_messages::{ActionResult, ActionStage};
Expand Down Expand Up @@ -79,6 +79,10 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> {
// always be zero if there are no actions running and no actions being waited
// on by the scheduler.
actions_in_transit: Arc<AtomicU64>,
// The number of actions a worker has completed.
actions_completed: Arc<AtomicU64>,
// The number of actions a worker has been assigned.
actions_assigned: Arc<AtomicU64>,
metrics: Arc<Metrics>,
}

Expand Down Expand Up @@ -126,6 +130,8 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
config,
grpc_client,
worker_id,
actions_completed: Arc::new(AtomicU64::new(0)),
actions_assigned: Arc::new(AtomicU64::new(0)),
running_actions_manager,
// Number of actions that have been received in `Update::StartAction`, but
// not yet processed by running_actions_manager's spawn. This number should
Expand Down Expand Up @@ -226,6 +232,9 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let worker_id_clone = worker_id.clone();
let precondition_script_cfg = self.config.experimental_precondition_script.clone();
let actions_in_transit = self.actions_in_transit.clone();

let actions_completed = self.actions_completed.clone();

let start_action_fut = self.metrics.clone().wrap(move |metrics| async move {
metrics.preconditions.wrap(preconditions_met(precondition_script_cfg))
.and_then(|_| running_actions_manager.create_and_add_action(worker_id_clone, start_execute))
Expand All @@ -244,6 +253,8 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
.and_then(RunningAction::get_finished_result)
// Note: We need ensure we run cleanup even if one of the other steps fail.
.then(|result| async move {

actions_completed.fetch_add(1, Ordering::Release);
if let Err(e) = action.cleanup().await {
return Result::<ActionResult, Error>::Err(e).merge(result);
}
Expand All @@ -267,7 +278,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let action_stage = ActionStage::Completed(action_result);
grpc_client.execution_response(
ExecuteResult{
worker_id,
worker_id: worker_id.clone(),
instance_name,
action_digest,
salt,
Expand All @@ -291,6 +302,19 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
};

self.actions_in_transit.fetch_add(1, Ordering::Release);
self.actions_assigned.fetch_add(1, Ordering::Release);
let execution_limit = &self.config.execution_limit;
if Some(self.actions_assigned.load(Ordering::Acquire)) == *execution_limit {

let mut grpc_client = self.grpc_client.clone();
grpc_client.going_away(
GoingAwayRequest {
worker_id: self.worker_id.clone(),
}
)
.await
.err_tip(|| "Error while calling execution_response")?;
}
futures.push(
tokio::spawn(start_action_fut).map(move |res| {
let res = res.err_tip(|| "Failed to launch spawn")?;
Expand All @@ -300,6 +324,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
add_future_channel
.send(make_publish_future(res).boxed())
.map_err(|_| make_err!(Code::Internal, "LocalWorker could not send future"))?;

Ok(())
})
.boxed()
Expand All @@ -313,6 +338,14 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
},
res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
};
if let Some(limit) = self.config.execution_limit {
let completed = self.actions_completed.load(Ordering::Acquire);
let assigned = self.actions_assigned.load(Ordering::Acquire);
// Futures will never be empty due to keep alive
if futures.len() == 1 && assigned == limit && completed == assigned {
std::process::exit(0);
}
}
}
// Unreachable.
}
Expand Down
150 changes: 149 additions & 1 deletion nativelink-worker/tests/local_worker_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod utils {
pub(crate) mod mock_running_actions_manager;
}

use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty};
use nativelink_config::cas_server::{EndpointConfig, LocalWorkerConfig, WorkerProperty};
use nativelink_error::{make_err, make_input_err, Code, Error};
use nativelink_proto::build::bazel::remote::execution::v2::platform::Property;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
Expand Down Expand Up @@ -638,4 +638,152 @@ mod local_worker_tests {

Ok(())
}

/// Exercises a worker configured with `execution_limit: Some(1)`.
///
/// Flow: connect the worker, assign a single action, expect the worker to
/// send a `GoingAwayRequest` carrying its own worker id, then let the action
/// run to completion and check the `ExecuteResult` reported to the scheduler.
///
/// NOTE(review): the worker's own shutdown after the limit is reached is not
/// observed by this test — confirm that it cannot end the test process
/// before the final assertions run.
#[tokio::test]
async fn worker_action_limit_test() -> Result<(), Box<dyn std::error::Error>> {
    const SALT_1: u64 = 1000;

    // NOTE(review): presumably large so that no keep-alive is sent during the
    // test (the mock client's `keep_alive` is `unreachable!()`) — confirm.
    const ARBITRARY_LARGE_TIMEOUT: f32 = 10000.;
    let local_worker_config = LocalWorkerConfig {
        platform_properties: HashMap::new(),
        worker_api_endpoint: EndpointConfig {
            timeout: Some(ARBITRARY_LARGE_TIMEOUT),
            ..Default::default()
        },
        // A single assigned action is enough to hit the limit.
        execution_limit: Some(1),
        ..Default::default()
    };
    let mut test_context = setup_local_worker_with_config(local_worker_config).await;
    let streaming_response = test_context.maybe_streaming_response.take().unwrap();

    {
        // Ensure our worker connects and properties were sent.
        let props = test_context
            .client
            .expect_connect_worker(Ok(streaming_response))
            .await;
        assert_eq!(props, SupportedProperties::default());
    }

    let expected_worker_id = "foobar".to_string();

    let mut tx_stream = test_context.maybe_tx_stream.take().unwrap();
    {
        // First initialize our worker by sending the response to the connection request.
        tx_stream
            .send_data(encode_stream_proto(&UpdateForWorker {
                update: Some(Update::ConnectionResult(ConnectionResult {
                    worker_id: expected_worker_id.clone(),
                })),
            })?)
            .await
            .map_err(|e| make_input_err!("Could not send : {:?}", e))?;
    }

    let action_digest_1 = DigestInfo::new([3u8; 32], 10);
    let action_info_1 = ActionInfo {
        command_digest: DigestInfo::new([1u8; 32], 10),
        input_root_digest: DigestInfo::new([2u8; 32], 10),
        timeout: Duration::from_secs(1),
        platform_properties: PlatformProperties::default(),
        priority: 0,
        load_timestamp: SystemTime::UNIX_EPOCH,
        insert_timestamp: SystemTime::UNIX_EPOCH,
        unique_qualifier: ActionInfoHashKey {
            instance_name: INSTANCE_NAME.to_string(),
            digest: action_digest_1,
            salt: SALT_1,
        },
        skip_cache_lookup: true,
        digest_function: DigestHasherFunc::Sha256,
    };

    {
        // Send execution request.
        tx_stream
            .send_data(encode_stream_proto(&UpdateForWorker {
                update: Some(Update::StartAction(StartExecute {
                    execute_request: Some(action_info_1.into()),
                    salt: SALT_1,
                    queued_timestamp: None,
                })),
            })?)
            .await
            .map_err(|e| make_input_err!("Could not send : {:?}", e))?;
    }

    // going_away should trigger once execution_limit actions have been assigned
    let going_away_request = test_context
        .client
        .expect_going_away(Ok(Response::new(())))
        .await;
    // The notification must identify this worker, not merely be sent.
    assert_eq!(going_away_request.worker_id, expected_worker_id);

    let action_result = ActionResult {
        output_files: vec![],
        output_folders: vec![],
        output_file_symlinks: vec![],
        output_directory_symlinks: vec![],
        exit_code: 5,
        stdout_digest: DigestInfo::new([21u8; 32], 10),
        stderr_digest: DigestInfo::new([22u8; 32], 10),
        execution_metadata: ExecutionMetadata {
            worker: expected_worker_id.clone(),
            queued_timestamp: SystemTime::UNIX_EPOCH,
            worker_start_timestamp: SystemTime::UNIX_EPOCH,
            worker_completed_timestamp: SystemTime::UNIX_EPOCH,
            input_fetch_start_timestamp: SystemTime::UNIX_EPOCH,
            input_fetch_completed_timestamp: SystemTime::UNIX_EPOCH,
            execution_start_timestamp: SystemTime::UNIX_EPOCH,
            execution_completed_timestamp: SystemTime::UNIX_EPOCH,
            output_upload_start_timestamp: SystemTime::UNIX_EPOCH,
            output_upload_completed_timestamp: SystemTime::UNIX_EPOCH,
        },
        server_logs: HashMap::new(),
        error: None,
        message: String::new(),
    };
    let running_action = Arc::new(MockRunningAction::new());

    // Send and wait for response from create_and_add_action to RunningActionsManager.
    test_context
        .actions_manager
        .expect_create_and_add_action(Ok(running_action.clone()))
        .await;

    // Now the RunningAction needs to send a series of state updates. This shortcuts them
    // into a single call (shortcut for prepare, execute, upload, collect_results, cleanup).
    running_action
        .simple_expect_get_finished_result(Ok(action_result.clone()))
        .await?;

    // Expect the action to be updated in the action cache.
    let (stored_digest, stored_result, digest_hasher) = test_context
        .actions_manager
        .expect_cache_action_result()
        .await;
    assert_eq!(stored_digest, action_digest_1);
    assert_eq!(stored_result, action_result.clone());
    assert_eq!(digest_hasher, DigestHasherFunc::Sha256);

    // Now our client should be notified that our runner finished.
    let execution_response = test_context
        .client
        .expect_execution_response(Ok(Response::new(())))
        .await;

    // Now ensure the final results match our expectations.
    assert_eq!(
        execution_response,
        ExecuteResult {
            worker_id: expected_worker_id.clone(),
            instance_name: INSTANCE_NAME.to_string(),
            action_digest: Some(action_digest_1.into()),
            salt: SALT_1,
            result: Some(execute_result::Result::ExecuteResponse(
                ActionStage::Completed(action_result).into()
            )),
        }
    );
    Ok(())
}
}
57 changes: 55 additions & 2 deletions nativelink-worker/tests/utils/local_worker_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ use super::mock_running_actions_manager::MockRunningActionsManager;
// Requests recorded by the mock worker API client. The mock's trait methods
// send one of these over `tx_call` for each call they receive, and the
// `expect_*` helpers read them back from `rx_call` to check which call was
// made and with what payload.
enum WorkerClientApiCalls {
// Payload of a `connect_worker` call.
ConnectWorker(SupportedProperties),
// Payload of an `execution_response` call.
ExecutionResponse(ExecuteResult),
// Payload of a `going_away` call.
GoingAwayRequest(GoingAwayRequest),
}

// Scripted replies for the mock worker API client. An `expect_*` helper sends
// one of these over `tx_resp`, and the matching trait method receives it from
// `rx_resp` and returns the wrapped `Result` to its caller.
#[derive(Debug)]
enum WorkerClientApiReturns {
// Reply handed back from `connect_worker`.
ConnectWorker(Result<Response<Streaming<UpdateForWorker>>, Status>),
// Reply handed back from `execution_response`.
ExecutionResponse(Result<Response<()>, Status>),
// Reply handed back from `going_away`.
GoingAwayRequest(Result<Response<()>, Status>),
}

#[derive(Clone)]
Expand Down Expand Up @@ -79,6 +81,29 @@ impl Default for MockWorkerApiClient {
}

impl MockWorkerApiClient {
/// Waits for the next recorded API call, requires it to be `going_away`,
/// and supplies `result` as the reply that call will return.
///
/// Returns the `GoingAwayRequest` that was received. Panics if the next
/// recorded call is a different API call, or if receiving the call or
/// sending the reply fails.
pub async fn expect_going_away(
    &mut self,
    result: Result<Response<()>, Status>,
) -> GoingAwayRequest {
    // The guard is kept until the function returns, so the reply below is
    // sent while the call receiver is still locked.
    let mut pending_calls = self.rx_call.lock().await;
    let next_call = pending_calls
        .recv()
        .await
        .expect("Could not receive msg in mpsc");
    let going_away_request = match next_call {
        WorkerClientApiCalls::GoingAwayRequest(request) => request,
        // Every other variant is listed explicitly (no `_` arm) so adding a
        // new call type forces this helper to be revisited.
        req @ (WorkerClientApiCalls::ConnectWorker(_)
        | WorkerClientApiCalls::ExecutionResponse(_)) => {
            panic!("expect_going_away expected GoingAwayRequest, got : {req:?}")
        }
    };
    let reply = WorkerClientApiReturns::GoingAwayRequest(result);
    self.tx_resp
        .send(reply)
        .expect("Could not send request to mpsc");
    going_away_request
}
pub async fn expect_connect_worker(
&mut self,
result: Result<Response<Streaming<UpdateForWorker>>, Status>,
Expand All @@ -93,6 +118,9 @@ impl MockWorkerApiClient {
req @ WorkerClientApiCalls::ExecutionResponse(_) => {
panic!("expect_connect_worker expected ConnectWorker, got : {req:?}")
}
req @ WorkerClientApiCalls::GoingAwayRequest(_) => {
panic!("expect_connect_worker expected ConnectWorker, got : {req:?}")
}
};
self.tx_resp
.send(WorkerClientApiReturns::ConnectWorker(result))
Expand All @@ -114,6 +142,9 @@ impl MockWorkerApiClient {
req @ WorkerClientApiCalls::ConnectWorker(_) => {
panic!("expect_execution_response expected ExecutionResponse, got : {req:?}")
}
req @ WorkerClientApiCalls::GoingAwayRequest(_) => {
panic!("expect_execution_response expected ExecutionResponse, got : {req:?}")
}
};
self.tx_resp
.send(WorkerClientApiReturns::ExecutionResponse(result))
Expand Down Expand Up @@ -141,15 +172,34 @@ impl WorkerApiClientTrait for MockWorkerApiClient {
resp @ WorkerClientApiReturns::ExecutionResponse(_) => {
panic!("connect_worker expected ConnectWorker response, received {resp:?}")
}
resp @ WorkerClientApiReturns::GoingAwayRequest(_) => {
panic!("connect_worker expected ConnectWorker response, received {resp:?}")
}
}
}

// Keep-alive calls are not scripted by this mock: there is no call/reply
// variant for them, so reaching this method panics via `unreachable!()`.
// NOTE(review): presumably tests configure a timeout large enough that the
// worker never sends a keep-alive — confirm against the test setup.
async fn keep_alive(&mut self, _request: KeepAliveRequest) -> Result<Response<()>, Status> {
unreachable!();
}

async fn going_away(&mut self, _request: GoingAwayRequest) -> Result<Response<()>, Status> {
unreachable!();
/// Mock `going_away`: records the request for the test to inspect, then
/// waits for the reply the test scripted.
///
/// The request is sent over `tx_call` as `WorkerClientApiCalls::GoingAwayRequest`;
/// the returned value is whatever `Result` arrives over `rx_resp` wrapped in
/// `WorkerClientApiReturns::GoingAwayRequest`.
///
/// Panics if the request cannot be sent, if no reply can be received, or if
/// the reply was scripted for a different API call.
async fn going_away(&mut self, request: GoingAwayRequest) -> Result<Response<()>, Status> {
    self.tx_call
        .send(WorkerClientApiCalls::GoingAwayRequest(request))
        .expect("Could not send request to mpsc");
    let mut rx_resp_lock = self.rx_resp.lock().await;
    match rx_resp_lock
        .recv()
        .await
        .expect("Could not receive msg in mpsc")
    {
        WorkerClientApiReturns::GoingAwayRequest(result) => result,
        // Listed explicitly (no `_` arm) so a new reply variant must be
        // handled here. The message names this method, matching the sibling
        // mocks ("connect_worker expected ...", "execution_response expected ...").
        resp @ (WorkerClientApiReturns::ConnectWorker(_)
        | WorkerClientApiReturns::ExecutionResponse(_)) => {
            panic!("going_away expected GoingAwayRequest response, received {resp:?}")
        }
    }
}

async fn execution_response(&mut self, request: ExecuteResult) -> Result<Response<()>, Status> {
Expand All @@ -166,6 +216,9 @@ impl WorkerApiClientTrait for MockWorkerApiClient {
resp @ WorkerClientApiReturns::ConnectWorker(_) => {
panic!("execution_response expected ExecutionResponse response, received {resp:?}")
}
resp @ WorkerClientApiReturns::GoingAwayRequest(_) => {
panic!("execution_response expected ExecutionResponse response, received {resp:?}")
}
}
}
}
Expand Down

0 comments on commit d295c88

Please sign in to comment.