Kill worker after action execution limit is reached #825

Open · wants to merge 1 commit into base: main
9 changes: 9 additions & 0 deletions nativelink-config/src/cas_server.rs
@@ -600,6 +600,15 @@ pub struct LocalWorkerConfig {
/// of the environment variable being the value of the property of the
/// action being executed of that name or the fixed value.
pub additional_environment: Option<HashMap<String, EnvironmentSource>>,

/// The maximum number of actions a worker can execute.
/// After this limit is reached, the nativelink binary will exit.
///
/// For example:
/// If you would like each individual action to spin up a Kubernetes
/// pod and then exit on completion, you would set this value to 1.
/// Default: None (no limit)
pub actions_before_termination: Option<u64>,
}

#[allow(non_camel_case_types)]
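For context, the snippet below is a minimal sketch (not part of this PR) of how the new field might be set when building a worker config in Rust. It assumes only that `LocalWorkerConfig` implements `Default`, which the test added in this PR also relies on.

```rust
use nativelink_config::cas_server::LocalWorkerConfig;

fn main() {
    // A worker that executes exactly one action and then exits, so an external
    // orchestrator (e.g. a Kubernetes Job) can schedule a fresh pod per action.
    let one_shot_worker = LocalWorkerConfig {
        actions_before_termination: Some(1),
        ..Default::default()
    };
    assert_eq!(one_shot_worker.actions_before_termination, Some(1));
}
```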
47 changes: 45 additions & 2 deletions nativelink-worker/src/local_worker.rs
@@ -27,7 +27,7 @@ use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
execute_result, ExecuteResult, KeepAliveRequest, UpdateForWorker,
execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker,
};
use nativelink_store::fast_slow_store::FastSlowStore;
use nativelink_util::action_messages::{ActionResult, ActionStage};
@@ -79,6 +79,10 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> {
// always be zero if there are no actions running and no actions being waited
// on by the scheduler.
actions_in_transit: Arc<AtomicU64>,
// The number of actions a worker has completed.
actions_completed: Arc<AtomicU64>,
// The number of actions a worker has been assigned.
actions_assigned: Arc<AtomicU64>,
metrics: Arc<Metrics>,
}

@@ -122,10 +126,18 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
running_actions_manager: Arc<U>,
metrics: Arc<Metrics>,
) -> Self {
if let Some(actions_before_termination) = config.actions_before_termination {
assert!(
actions_before_termination > 0,
"LocalWorkerImpl::new() - LocalWorkerConfig.actions_before_termination must be greater than 0"
)
}
Self {
config,
grpc_client,
worker_id,
actions_completed: Arc::new(AtomicU64::new(0)),
actions_assigned: Arc::new(AtomicU64::new(0)),
running_actions_manager,
// Number of actions that have been received in `Update::StartAction`, but
// not yet processed by running_actions_manager's spawn. This number should
@@ -226,6 +238,9 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let worker_id_clone = worker_id.clone();
let precondition_script_cfg = self.config.experimental_precondition_script.clone();
let actions_in_transit = self.actions_in_transit.clone();

let actions_completed = self.actions_completed.clone();

let start_action_fut = self.metrics.clone().wrap(move |metrics| async move {
metrics.preconditions.wrap(preconditions_met(precondition_script_cfg))
.and_then(|_| running_actions_manager.create_and_add_action(worker_id_clone, start_execute))
@@ -244,6 +259,8 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
.and_then(RunningAction::get_finished_result)
// Note: We need to ensure we run cleanup even if one of the other steps fails.
.then(|result| async move {

actions_completed.fetch_add(1, Ordering::Release);
if let Err(e) = action.cleanup().await {
return Result::<ActionResult, Error>::Err(e).merge(result);
}
@@ -267,7 +284,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
let action_stage = ActionStage::Completed(action_result);
grpc_client.execution_response(
ExecuteResult{
worker_id,
worker_id: worker_id.clone(),
instance_name,
action_digest,
salt,
@@ -291,6 +308,19 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
};

self.actions_in_transit.fetch_add(1, Ordering::Release);
self.actions_assigned.fetch_add(1, Ordering::Release);
let execution_limit = &self.config.actions_before_termination;
if Some(self.actions_assigned.load(Ordering::Acquire)) == *execution_limit {

let mut grpc_client = self.grpc_client.clone();
grpc_client.going_away(
GoingAwayRequest {
worker_id: self.worker_id.clone(),
}
)
.await
.err_tip(|| "Error while calling execution_response")?;
}
futures.push(
tokio::spawn(start_action_fut).map(move |res| {
let res = res.err_tip(|| "Failed to launch spawn")?;
@@ -300,6 +330,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
add_future_channel
.send(make_publish_future(res).boxed())
.map_err(|_| make_err!(Code::Internal, "LocalWorker could not send future"))?;

Ok(())
})
.boxed()
@@ -313,6 +344,18 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
},
res = futures.next() => res.err_tip(|| "Keep-alive should always be pending. Likely unable to send data to scheduler")??,
};
if let Some(limit) = self.config.actions_before_termination {
let completed = self.actions_completed.load(Ordering::Acquire);
let assigned = self.actions_assigned.load(Ordering::Acquire);
// `futures` is never empty because it always contains the keep-alive future,
// so a length of 1 means all spawned actions have drained.
if futures.len() == 1 && assigned == limit && completed == assigned {
eprintln!(
"Worker with id {} reached max executions - terminating",
self.worker_id
);
std::process::exit(0);
}
}
}
// Unreachable.
}
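To make the termination condition above easier to follow, here is a standalone sketch of the same counting logic using only the standard library. The names (`WorkerCounters`, `should_terminate`, `on_assigned`, `on_completed`) are illustrative, not part of the nativelink API; the real implementation additionally requires that only the keep-alive future remains in `futures` before exiting.

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// Illustrative only: mirrors the assigned/completed counters added in this PR.
struct WorkerCounters {
    assigned: Arc<AtomicU64>,
    completed: Arc<AtomicU64>,
}

impl WorkerCounters {
    fn new() -> Self {
        Self {
            assigned: Arc::new(AtomicU64::new(0)),
            completed: Arc::new(AtomicU64::new(0)),
        }
    }

    // Called when the scheduler assigns an action to this worker.
    fn on_assigned(&self) {
        self.assigned.fetch_add(1, Ordering::Release);
    }

    // Called after an action finishes (including cleanup).
    fn on_completed(&self) {
        self.completed.fetch_add(1, Ordering::Release);
    }

    // The worker may exit once the limit has been reached and every
    // assigned action has also completed.
    fn should_terminate(&self, limit: Option<u64>) -> bool {
        let Some(limit) = limit else { return false };
        let assigned = self.assigned.load(Ordering::Acquire);
        let completed = self.completed.load(Ordering::Acquire);
        assigned == limit && completed == assigned
    }
}

fn main() {
    let counters = WorkerCounters::new();
    counters.on_assigned();
    assert!(!counters.should_terminate(Some(1))); // action still running
    counters.on_completed();
    assert!(counters.should_terminate(Some(1))); // limit reached, safe to exit
}
```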
150 changes: 149 additions & 1 deletion nativelink-worker/tests/local_worker_test.rs
@@ -29,7 +29,7 @@ mod utils {
pub(crate) mod mock_running_actions_manager;
}

use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty};
use nativelink_config::cas_server::{EndpointConfig, LocalWorkerConfig, WorkerProperty};
use nativelink_error::{make_err, make_input_err, Code, Error};
use nativelink_proto::build::bazel::remote::execution::v2::platform::Property;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
@@ -638,4 +638,152 @@ mod local_worker_tests {

Ok(())
}

#[tokio::test]
async fn worker_action_limit_test() -> Result<(), Box<dyn std::error::Error>> {
const SALT_1: u64 = 1000;

const ARBITRARY_LARGE_TIMEOUT: f32 = 10000.;
let local_worker_config = LocalWorkerConfig {
platform_properties: HashMap::new(),
worker_api_endpoint: EndpointConfig {
timeout: Some(ARBITRARY_LARGE_TIMEOUT),
..Default::default()
},
actions_before_termination: Some(1),
..Default::default()
};
let mut test_context = setup_local_worker_with_config(local_worker_config).await;
let streaming_response = test_context.maybe_streaming_response.take().unwrap();

{
// Ensure our worker connects and properties were sent.
let props = test_context
.client
.expect_connect_worker(Ok(streaming_response))
.await;
assert_eq!(props, SupportedProperties::default());
}

let expected_worker_id = "foobar".to_string();

let mut tx_stream = test_context.maybe_tx_stream.take().unwrap();
{
// First initialize our worker by sending the response to the connection request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::ConnectionResult(ConnectionResult {
worker_id: expected_worker_id.clone(),
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

let action_digest_1 = DigestInfo::new([3u8; 32], 10);
let action_info_1 = ActionInfo {
command_digest: DigestInfo::new([1u8; 32], 10),
input_root_digest: DigestInfo::new([2u8; 32], 10),
timeout: Duration::from_secs(1),
platform_properties: PlatformProperties::default(),
priority: 0,
load_timestamp: SystemTime::UNIX_EPOCH,
insert_timestamp: SystemTime::UNIX_EPOCH,
unique_qualifier: ActionInfoHashKey {
instance_name: INSTANCE_NAME.to_string(),
digest: action_digest_1,
salt: SALT_1,
},
skip_cache_lookup: true,
digest_function: DigestHasherFunc::Sha256,
};

{
// Send execution request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::StartAction(StartExecute {
execute_request: Some(action_info_1.into()),
salt: SALT_1,
queued_timestamp: None,
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// going_away should trigger once `actions_before_termination` actions have been assigned.
test_context
.client
.expect_going_away(Ok(Response::new(())))
.await;

let action_result = ActionResult {
output_files: vec![],
output_folders: vec![],
output_file_symlinks: vec![],
output_directory_symlinks: vec![],
exit_code: 5,
stdout_digest: DigestInfo::new([21u8; 32], 10),
stderr_digest: DigestInfo::new([22u8; 32], 10),
execution_metadata: ExecutionMetadata {
worker: expected_worker_id.clone(),
queued_timestamp: SystemTime::UNIX_EPOCH,
worker_start_timestamp: SystemTime::UNIX_EPOCH,
worker_completed_timestamp: SystemTime::UNIX_EPOCH,
input_fetch_start_timestamp: SystemTime::UNIX_EPOCH,
input_fetch_completed_timestamp: SystemTime::UNIX_EPOCH,
execution_start_timestamp: SystemTime::UNIX_EPOCH,
execution_completed_timestamp: SystemTime::UNIX_EPOCH,
output_upload_start_timestamp: SystemTime::UNIX_EPOCH,
output_upload_completed_timestamp: SystemTime::UNIX_EPOCH,
},
server_logs: HashMap::new(),
error: None,
message: String::new(),
};
let running_action = Arc::new(MockRunningAction::new());

// Send and wait for response from create_and_add_action to RunningActionsManager.
test_context
.actions_manager
.expect_create_and_add_action(Ok(running_action.clone()))
.await;

// Now the RunningAction needs to send a series of state updates. This shortcuts them
// into a single call (shortcut for prepare, execute, upload, collect_results, cleanup).
running_action
.simple_expect_get_finished_result(Ok(action_result.clone()))
.await?;

// Expect the action to be updated in the action cache.
let (stored_digest, stored_result, digest_hasher) = test_context
.actions_manager
.expect_cache_action_result()
.await;
assert_eq!(stored_digest, action_digest_1);
assert_eq!(stored_result, action_result.clone());
assert_eq!(digest_hasher, DigestHasherFunc::Sha256);

// Now our client should be notified that our runner finished.
let execution_response = test_context
.client
.expect_execution_response(Ok(Response::new(())))
.await;

// Now ensure the final results match our expectations.
assert_eq!(
execution_response,
ExecuteResult {
worker_id: expected_worker_id.clone(),
instance_name: INSTANCE_NAME.to_string(),
action_digest: Some(action_digest_1.into()),
salt: SALT_1,
result: Some(execute_result::Result::ExecuteResponse(
ActionStage::Completed(action_result).into()
)),
}
);
Ok(())
}
}