Skip to content

Commit

Permalink
Connection Manager Rewrite
Browse files Browse the repository at this point in the history
The connection manager written in grpc_utils made incorrect assumptions about how
the tonic and tower implementations work, and so is not suitable for maintaining
multiple connections or ensuring stability.

Completely rewrite it so that the manager itself owns the tonic::Channel for each
tonic::Endpoint, which gives a simpler external API and ensures that connection
errors are handled correctly. This is done with a single worker loop that manages
all of the connections, and by wrapping each connection so that it reports its
state back to the worker.
  • Loading branch information
chrisstaite-menlo committed Apr 2, 2024
1 parent 00ff4a0 commit 4ec1cb2
Show file tree
Hide file tree
Showing 9 changed files with 588 additions and 198 deletions.
5 changes: 5 additions & 0 deletions nativelink-config/src/schedulers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ pub struct GrpcScheduler {
/// request is queued.
#[serde(default)]
pub max_concurrent_requests: usize,

/// The number of connections to make to each specified endpoint to balance
/// the load over multiple TCP connections. Default 1.
#[serde(default)]
pub connections_per_endpoint: usize,
}

#[derive(Deserialize, Debug)]
Expand Down
5 changes: 5 additions & 0 deletions nativelink-config/src/stores.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,11 @@ pub struct GrpcStore {
/// request is queued.
#[serde(default)]
pub max_concurrent_requests: usize,

/// The number of connections to make to each specified endpoint to balance
/// the load over multiple TCP connections. Default 1.
#[serde(default)]
pub connections_per_endpoint: usize,
}

/// The possible error codes that might occur on an upstream request.
Expand Down
45 changes: 25 additions & 20 deletions nativelink-scheduler/src/grpc_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use nativelink_proto::google::longrunning::Operation;
use nativelink_util::action_messages::{
ActionInfo, ActionInfoHashKey, ActionState, DEFAULT_EXECUTION_PRIORITY,
};
use nativelink_util::grpc_utils::ConnectionManager;
use nativelink_util::connection_manager::ConnectionManager;
use nativelink_util::retry::{Retrier, RetryResult};
use nativelink_util::tls_utils;
use parking_lot::Mutex;
Expand Down Expand Up @@ -72,16 +72,20 @@ impl GrpcScheduler {
jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
) -> Result<Self, Error> {
let endpoint = tls_utils::endpoint(&config.endpoint)?;
let jitter_fn = Arc::new(jitter_fn);
Ok(Self {
platform_property_managers: Mutex::new(HashMap::new()),
retrier: Retrier::new(
Arc::new(|duration| Box::pin(sleep(duration))),
Arc::new(jitter_fn),
jitter_fn.clone(),
config.retry.to_owned(),
),
connection_manager: ConnectionManager::new(
std::iter::once(endpoint),
config.connections_per_endpoint,
config.max_concurrent_requests,
config.retry.to_owned(),
jitter_fn,
),
})
}
Expand Down Expand Up @@ -164,16 +168,17 @@ impl ActionScheduler for GrpcScheduler {

self.perform_request(instance_name, |instance_name| async move {
// Not in the cache, lookup the capabilities with the upstream.
let (connection, channel) = self.connection_manager.get_connection().await;
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in get_platform_property_manager()")?;
let capabilities_result = CapabilitiesClient::new(channel)
.get_capabilities(GetCapabilitiesRequest {
instance_name: instance_name.to_string(),
})
.await
.err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
if let Err(err) = &capabilities_result {
connection.on_error(err);
}
let capabilities = capabilities_result?.into_inner();
let platform_property_manager = Arc::new(PlatformPropertyManager::new(
capabilities
Expand Down Expand Up @@ -220,15 +225,15 @@ impl ActionScheduler for GrpcScheduler {
};
let result_stream = self
.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ExecutionClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in add_action()")?;
ExecutionClient::new(channel)
.execute(Request::new(request))
.await
.err_tip(|| "Sending action to upstream scheduler");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "Sending action to upstream scheduler")
})
.await?
.into_inner();
Expand All @@ -244,15 +249,15 @@ impl ActionScheduler for GrpcScheduler {
};
let result_stream = self
.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ExecutionClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in find_existing_action()")?;
ExecutionClient::new(channel)
.wait_execution(Request::new(request))
.await
.err_tip(|| "While getting wait_execution stream");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "While getting wait_execution stream")
})
.and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
.await;
Expand Down
138 changes: 74 additions & 64 deletions nativelink-store/src/grpc_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::time::Duration;
use async_trait::async_trait;
use bytes::BytesMut;
use futures::stream::{unfold, FuturesUnordered};
use futures::{future, Future, Stream, StreamExt, TryStreamExt};
use futures::{future, Future, Stream, StreamExt, TryFutureExt, TryStreamExt};
use nativelink_error::{error_if, make_input_err, Error, ResultExt};
use nativelink_proto::build::bazel::remote::execution::v2::action_cache_client::ActionCacheClient;
use nativelink_proto::build::bazel::remote::execution::v2::content_addressable_storage_client::ContentAddressableStorageClient;
Expand All @@ -36,7 +36,7 @@ use nativelink_proto::google::bytestream::{
};
use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
use nativelink_util::common::DigestInfo;
use nativelink_util::grpc_utils::ConnectionManager;
use nativelink_util::connection_manager::ConnectionManager;
use nativelink_util::health_utils::HealthStatusIndicator;
use nativelink_util::proto_stream_utils::{
FirstStream, WriteRequestStreamWrapper, WriteState, WriteStateWrapper,
Expand Down Expand Up @@ -98,17 +98,21 @@ impl GrpcStore {
endpoints.push(endpoint);
}

let jitter_fn = Arc::new(jitter_fn);
Ok(GrpcStore {
instance_name: config.instance_name.clone(),
store_type: config.store_type,
retrier: Retrier::new(
Arc::new(|duration| Box::pin(sleep(duration))),
Arc::new(jitter_fn),
jitter_fn.clone(),
config.retry.to_owned(),
),
connection_manager: ConnectionManager::new(
endpoints.into_iter(),
config.connections_per_endpoint,
config.max_concurrent_requests,
config.retry.to_owned(),
jitter_fn,
),
})
}
Expand Down Expand Up @@ -145,15 +149,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ContentAddressableStorageClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in find_missing_blobs")?;
ContentAddressableStorageClient::new(channel)
.find_missing_blobs(Request::new(request))
.await
.err_tip(|| "in GrpcStore::find_missing_blobs");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::find_missing_blobs")
})
.await
}
Expand All @@ -170,15 +174,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ContentAddressableStorageClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in batch_update_blobs")?;
ContentAddressableStorageClient::new(channel)
.batch_update_blobs(Request::new(request))
.await
.err_tip(|| "in GrpcStore::batch_update_blobs");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::batch_update_blobs")
})
.await
}
Expand All @@ -195,15 +199,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ContentAddressableStorageClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in batch_read_blobs")?;
ContentAddressableStorageClient::new(channel)
.batch_read_blobs(Request::new(request))
.await
.err_tip(|| "in GrpcStore::batch_read_blobs");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::batch_read_blobs")
})
.await
}
Expand All @@ -220,15 +224,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ContentAddressableStorageClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in get_tree")?;
ContentAddressableStorageClient::new(channel)
.get_tree(Request::new(request))
.await
.err_tip(|| "in GrpcStore::get_tree");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::get_tree")
})
.await
}
Expand All @@ -247,15 +251,16 @@ impl GrpcStore {
&self,
request: ReadRequest,
) -> Result<impl Stream<Item = Result<ReadResponse, Status>>, Error> {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ByteStreamClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in read_internal")?;
let mut response = ByteStreamClient::new(channel)
.read(Request::new(request))
.await
.err_tip(|| "in GrpcStore::read");
if let Err(err) = &result {
connection.on_error(err);
}
let mut response = result?.into_inner();
.err_tip(|| "in GrpcStore::read")?
.into_inner();
let first_response = response
.message()
.await
Expand Down Expand Up @@ -300,14 +305,20 @@ impl GrpcStore {
let result = self
.retrier
.retry(unfold(local_state, move |local_state| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
// The client write may occur on a separate thread and
// therefore in order to share the state with it we have to
// wrap it in a Mutex and retrieve it after the write
// has completed. There is no way to get the value back
// from the client.
let result = ByteStreamClient::new(channel)
.write(WriteStateWrapper::new(local_state.clone()))
let result = self
.connection_manager
.connection()
.and_then(|channel| async {
ByteStreamClient::new(channel)
.write(WriteStateWrapper::new(local_state.clone()))
.await
.err_tip(|| "in GrpcStore::write")
})
.await;

// Get the state back from StateWrapper, this should be
Expand All @@ -319,9 +330,8 @@ impl GrpcStore {
RetryResult::Err(err.append("Where read_stream_error was set"))
} else {
// On error determine whether it is possible to retry.
match result.err_tip(|| "in GrpcStore::write") {
match result {
Err(err) => {
connection.on_error(&err);
if local_state_locked.can_resume() {
local_state_locked.resume();
RetryResult::Retry(err)
Expand Down Expand Up @@ -359,15 +369,15 @@ impl GrpcStore {
}

self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ByteStreamClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in query_write_status")?;
ByteStreamClient::new(channel)
.query_write_status(Request::new(request))
.await
.err_tip(|| "in GrpcStore::query_write_status");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::query_write_status")
})
.await
}
Expand All @@ -379,15 +389,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ActionCacheClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in get_action_result")?;
ActionCacheClient::new(channel)
.get_action_result(Request::new(request))
.await
.err_tip(|| "in GrpcStore::get_action_result");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::get_action_result")
})
.await
}
Expand All @@ -399,15 +409,15 @@ impl GrpcStore {
let mut request = grpc_request.into_inner();
request.instance_name = self.instance_name.clone();
self.perform_request(request, |request| async move {
let (connection, channel) = self.connection_manager.get_connection().await;
let result = ActionCacheClient::new(channel)
let channel = self
.connection_manager
.connection()
.await
.err_tip(|| "in update_action_result")?;
ActionCacheClient::new(channel)
.update_action_result(Request::new(request))
.await
.err_tip(|| "in GrpcStore::update_action_result");
if let Err(err) = &result {
connection.on_error(err);
}
result
.err_tip(|| "in GrpcStore::update_action_result")
})
.await
}
Expand Down
2 changes: 1 addition & 1 deletion nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ rust_library(
"src/action_messages.rs",
"src/buf_channel.rs",
"src/common.rs",
"src/connection_manager.rs",
"src/digest_hasher.rs",
"src/evicting_map.rs",
"src/fastcdc.rs",
"src/fs.rs",
"src/grpc_utils.rs",
"src/health_utils.rs",
"src/lib.rs",
"src/metrics_utils.rs",
Expand Down
Loading

0 comments on commit 4ec1cb2

Please sign in to comment.