Skip to content

Commit

Permalink
scheduler opt-in forwarding (#1801)
Browse files Browse the repository at this point in the history
  • Loading branch information
apfitzge authored Jul 12, 2024
1 parent ac9ff97 commit 61d8be0
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 52 deletions.
1 change: 1 addition & 0 deletions banking-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ fn main() {
Arc::new(connection_cache),
bank_forks.clone(),
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);

// This is so that the signal_receiver does not go out of scope after the closure.
Expand Down
1 change: 1 addition & 0 deletions core/benches/banking_stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ fn bench_banking(bencher: &mut Bencher, tx_type: TransactionType) {
Arc::new(ConnectionCache::new("connection_cache_test")),
bank_forks,
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);

let chunk_len = verified.len() / CHUNKS;
Expand Down
25 changes: 18 additions & 7 deletions core/src/banking_stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ impl BankingStage {
connection_cache: Arc<ConnectionCache>,
bank_forks: Arc<RwLock<BankForks>>,
prioritization_fee_cache: &Arc<PrioritizationFeeCache>,
enable_forwarding: bool,
) -> Self {
Self::new_num_threads(
block_production_method,
Expand All @@ -354,6 +355,7 @@ impl BankingStage {
connection_cache,
bank_forks,
prioritization_fee_cache,
enable_forwarding,
)
}

Expand All @@ -372,6 +374,7 @@ impl BankingStage {
connection_cache: Arc<ConnectionCache>,
bank_forks: Arc<RwLock<BankForks>>,
prioritization_fee_cache: &Arc<PrioritizationFeeCache>,
enable_forwarding: bool,
) -> Self {
match block_production_method {
BlockProductionMethod::ThreadLocalMultiIterator => {
Expand Down Expand Up @@ -403,6 +406,7 @@ impl BankingStage {
connection_cache,
bank_forks,
prioritization_fee_cache,
enable_forwarding,
),
}
}
Expand Down Expand Up @@ -505,6 +509,7 @@ impl BankingStage {
connection_cache: Arc<ConnectionCache>,
bank_forks: Arc<RwLock<BankForks>>,
prioritization_fee_cache: &Arc<PrioritizationFeeCache>,
enable_forwarding: bool,
) -> Self {
assert!(num_threads >= MIN_TOTAL_THREADS);
// Single thread to generate entries from many banks.
Expand Down Expand Up @@ -586,13 +591,15 @@ impl BankingStage {
)
}

let forwarder = Forwarder::new(
poh_recorder.clone(),
bank_forks.clone(),
cluster_info.clone(),
connection_cache.clone(),
data_budget.clone(),
);
let forwarder = enable_forwarding.then(|| {
Forwarder::new(
poh_recorder.clone(),
bank_forks.clone(),
cluster_info.clone(),
connection_cache.clone(),
data_budget.clone(),
)
});

// Spawn the central scheduler thread
bank_thread_hdls.push({
Expand Down Expand Up @@ -883,6 +890,7 @@ mod tests {
Arc::new(ConnectionCache::new("connection_cache_test")),
bank_forks,
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);
drop(non_vote_sender);
drop(tpu_vote_sender);
Expand Down Expand Up @@ -938,6 +946,7 @@ mod tests {
Arc::new(ConnectionCache::new("connection_cache_test")),
bank_forks,
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);
trace!("sending bank");
drop(non_vote_sender);
Expand Down Expand Up @@ -1017,6 +1026,7 @@ mod tests {
Arc::new(ConnectionCache::new("connection_cache_test")),
bank_forks.clone(), // keep a local-copy of bank-forks so worker threads do not lose weak access to bank-forks
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);

// fund another account so we can send 2 good transactions in a single batch.
Expand Down Expand Up @@ -1378,6 +1388,7 @@ mod tests {
Arc::new(ConnectionCache::new("connection_cache_test")),
bank_forks,
&Arc::new(PrioritizationFeeCache::new(0u64)),
false,
);

let keypairs = (0..100).map(|_| Keypair::new()).collect_vec();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ pub(crate) struct SchedulerController {
timing_metrics: SchedulerTimingMetrics,
/// Metric report handles for the worker threads.
worker_metrics: Vec<Arc<ConsumeWorkerMetrics>>,
/// State for forwarding packets to the leader.
forwarder: Forwarder,
/// State for forwarding packets to the leader, if enabled.
forwarder: Option<Forwarder>,
}

impl SchedulerController {
Expand All @@ -76,7 +76,7 @@ impl SchedulerController {
bank_forks: Arc<RwLock<BankForks>>,
scheduler: PrioGraphScheduler,
worker_metrics: Vec<Arc<ConsumeWorkerMetrics>>,
forwarder: Forwarder,
forwarder: Option<Forwarder>,
) -> Self {
Self {
decision_maker,
Expand Down Expand Up @@ -147,6 +147,7 @@ impl SchedulerController {
&mut self,
decision: &BufferedPacketsDecision,
) -> Result<(), SchedulerError> {
let forwarding_enabled = self.forwarder.is_some();
match decision {
BufferedPacketsDecision::Consume(bank_start) => {
let (scheduling_summary, schedule_time_us) = measure_us!(self.scheduler.schedule(
Expand Down Expand Up @@ -186,16 +187,30 @@ impl SchedulerController {
});
}
BufferedPacketsDecision::Forward => {
let (_, forward_time_us) = measure_us!(self.forward_packets(false));
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us);
});
if forwarding_enabled {
let (_, forward_time_us) = measure_us!(self.forward_packets(false));
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us);
});
} else {
let (_, clear_time_us) = measure_us!(self.clear_container());
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.clear_time_us, clear_time_us);
});
}
}
BufferedPacketsDecision::ForwardAndHold => {
let (_, forward_time_us) = measure_us!(self.forward_packets(true));
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us);
});
if forwarding_enabled {
let (_, forward_time_us) = measure_us!(self.forward_packets(true));
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us);
});
} else {
let (_, clean_time_us) = measure_us!(self.clean_queue());
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.clean_time_us, clean_time_us);
});
}
}
BufferedPacketsDecision::Hold => {}
}
Expand Down Expand Up @@ -234,6 +249,7 @@ impl SchedulerController {
let start = Instant::now();
let bank = self.bank_forks.read().unwrap().working_bank();
let feature_set = &bank.feature_set;
let forwarder = self.forwarder.as_mut().expect("forwarder must exist");

// Pop from the container in chunks, filter using bank checks, then attempt to forward.
// This doubles as a way to clean the queue as well as forwarding transactions.
Expand Down Expand Up @@ -282,7 +298,7 @@ impl SchedulerController {

// If not already forwarded and can be forwarded, add to forwardable packets.
if state.should_forward()
&& self.forwarder.try_add_packet(
&& forwarder.try_add_packet(
sanitized_transaction,
immutable_packet,
feature_set,
Expand All @@ -300,9 +316,8 @@ impl SchedulerController {
}

// Forward each batch of transactions
self.forwarder
.forward_batched_packets(&ForwardOption::ForwardTransaction);
self.forwarder.clear_batches();
forwarder.forward_batched_packets(&ForwardOption::ForwardTransaction);
forwarder.clear_batches();

// If we hit the time limit. Drop everything that was not checked/processed.
// If we cannot run these simple checks in time, then we cannot run them during
Expand Down Expand Up @@ -330,7 +345,6 @@ impl SchedulerController {

/// Clears the transaction state container.
/// This only clears pending transactions, and does **not** clear in-flight transactions.
#[allow(dead_code)]
fn clear_container(&mut self) {
let mut num_dropped_on_clear: usize = 0;
while let Some(id) = self.container.pop() {
Expand All @@ -346,7 +360,6 @@ impl SchedulerController {
/// Clean unprocessable transactions from the queue. These will be transactions that are
/// expired, already processed, or are no longer sanitizable.
/// This only clears pending transactions, and does **not** clear in-flight transactions.
#[allow(dead_code)]
fn clean_queue(&mut self) {
// Clean up any transactions that have already been processed, are too old, or do not have
// valid nonce accounts.
Expand Down Expand Up @@ -424,17 +437,19 @@ impl SchedulerController {
let remaining_queue_capacity = self.container.remaining_queue_capacity();

const MAX_PACKET_RECEIVE_TIME: Duration = Duration::from_millis(100);
let recv_timeout = match decision {
BufferedPacketsDecision::Consume(_) => {
let (recv_timeout, should_buffer) = match decision {
BufferedPacketsDecision::Consume(_) => (
if self.container.is_empty() {
MAX_PACKET_RECEIVE_TIME
} else {
Duration::ZERO
}
},
true,
),
BufferedPacketsDecision::Forward => (MAX_PACKET_RECEIVE_TIME, false),
BufferedPacketsDecision::ForwardAndHold | BufferedPacketsDecision::Hold => {
(MAX_PACKET_RECEIVE_TIME, true)
}
BufferedPacketsDecision::Forward
| BufferedPacketsDecision::ForwardAndHold
| BufferedPacketsDecision::Hold => MAX_PACKET_RECEIVE_TIME,
};

let (received_packet_results, receive_time_us) = measure_us!(self
Expand All @@ -456,11 +471,21 @@ impl SchedulerController {
saturating_add_assign!(count_metrics.num_received, num_received_packets);
});

let (_, buffer_time_us) =
measure_us!(self.buffer_packets(receive_packet_results.deserialized_packets));
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.buffer_time_us, buffer_time_us);
});
if should_buffer {
let (_, buffer_time_us) = measure_us!(
self.buffer_packets(receive_packet_results.deserialized_packets)
);
self.timing_metrics.update(|timing_metrics| {
saturating_add_assign!(timing_metrics.buffer_time_us, buffer_time_us);
});
} else {
self.count_metrics.update(|count_metrics| {
saturating_add_assign!(
count_metrics.num_dropped_on_receive,
num_received_packets
);
});
}
}
Err(RecvTimeoutError::Timeout) => {}
Err(RecvTimeoutError::Disconnected) => return false,
Expand Down Expand Up @@ -636,14 +661,13 @@ mod tests {
banking_stage::{
consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH,
scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId},
tests::{create_slow_genesis_config, new_test_cluster_info},
tests::create_slow_genesis_config,
},
banking_trace::BankingPacketBatch,
sigverify::SigverifyTracerPacketStats,
},
crossbeam_channel::{unbounded, Receiver, Sender},
itertools::Itertools,
solana_client::connection_cache::ConnectionCache,
solana_ledger::{
blockstore::Blockstore, genesis_utils::GenesisConfigInfo,
get_tmp_ledger_path_auto_delete, leader_schedule_cache::LeaderScheduleCache,
Expand Down Expand Up @@ -712,17 +736,6 @@ mod tests {
let (consume_work_senders, consume_work_receivers) = create_channels(num_threads);
let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();

let validator_keypair = Arc::new(Keypair::new());
let (_local_node, cluster_info) = new_test_cluster_info(Some(validator_keypair));
let cluster_info = Arc::new(cluster_info);
let forwarder = Forwarder::new(
poh_recorder.clone(),
bank_forks.clone(),
cluster_info,
Arc::new(ConnectionCache::new("connection_cache_test")),
Arc::default(),
);

let test_frame = TestFrame {
bank,
mint_keypair,
Expand All @@ -741,7 +754,7 @@ mod tests {
bank_forks,
PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver),
vec![], // no actual workers with metrics to report, this can be empty
forwarder,
None,
);

(test_frame, scheduler_controller)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ pub struct SchedulerTimingMetricsInner {
pub schedule_filter_time_us: u64,
/// Time spent scheduling transactions.
pub schedule_time_us: u64,
/// Time spent clearing transactions from the container.
pub clear_time_us: u64,
/// Time spent cleaning expired or processed transactions from the container.
pub clean_time_us: u64,
/// Time spent forwarding transactions.
pub forward_time_us: u64,
/// Time spent receiving completed transactions.
Expand Down Expand Up @@ -312,6 +316,8 @@ impl SchedulerTimingMetricsInner {
("buffer_time_us", self.buffer_time_us, i64),
("schedule_filter_time_us", self.schedule_filter_time_us, i64),
("schedule_time_us", self.schedule_time_us, i64),
("clear_time_us", self.clear_time_us, i64),
("clean_time_us", self.clean_time_us, i64),
("forward_time_us", self.forward_time_us, i64),
(
"receive_completed_time_us",
Expand All @@ -331,6 +337,8 @@ impl SchedulerTimingMetricsInner {
self.buffer_time_us = 0;
self.schedule_filter_time_us = 0;
self.schedule_time_us = 0;
self.clear_time_us = 0;
self.clean_time_us = 0;
self.forward_time_us = 0;
self.receive_completed_time_us = 0;
}
Expand Down
2 changes: 2 additions & 0 deletions core/src/tpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ impl Tpu {
tpu_max_connections_per_ipaddr_per_minute: u64,
prioritization_fee_cache: &Arc<PrioritizationFeeCache>,
block_production_method: BlockProductionMethod,
enable_block_production_forwarding: bool,
_generator_config: Option<GeneratorConfig>, /* vestigial code for replay invalidator */
) -> (Self, Vec<Arc<dyn NotifyKeyUpdate + Sync + Send>>) {
let TpuSockets {
Expand Down Expand Up @@ -246,6 +247,7 @@ impl Tpu {
connection_cache.clone(),
bank_forks.clone(),
prioritization_fee_cache,
enable_block_production_forwarding,
);

let (entry_receiver, tpu_entry_notifier) =
Expand Down
4 changes: 4 additions & 0 deletions core/src/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ pub struct ValidatorConfig {
pub banking_trace_dir_byte_limit: banking_trace::DirByteLimit,
pub block_verification_method: BlockVerificationMethod,
pub block_production_method: BlockProductionMethod,
pub enable_block_production_forwarding: bool,
pub generator_config: Option<GeneratorConfig>,
pub use_snapshot_archives_at_startup: UseSnapshotArchivesAtStartup,
pub wen_restart_proto_path: Option<PathBuf>,
Expand Down Expand Up @@ -337,6 +338,7 @@ impl Default for ValidatorConfig {
banking_trace_dir_byte_limit: 0,
block_verification_method: BlockVerificationMethod::default(),
block_production_method: BlockProductionMethod::default(),
enable_block_production_forwarding: false,
generator_config: None,
use_snapshot_archives_at_startup: UseSnapshotArchivesAtStartup::default(),
wen_restart_proto_path: None,
Expand All @@ -355,6 +357,7 @@ impl ValidatorConfig {
enforce_ulimit_nofile: false,
rpc_config: JsonRpcConfig::default_for_test(),
block_production_method: BlockProductionMethod::default(),
enable_block_production_forwarding: true, // enable forwarding by default for tests
replay_forks_threads: NonZeroUsize::new(1).expect("1 is non-zero"),
replay_transactions_threads: NonZeroUsize::new(get_max_thread_count())
.expect("thread count is non-zero"),
Expand Down Expand Up @@ -1432,6 +1435,7 @@ impl Validator {
tpu_max_connections_per_ipaddr_per_minute,
&prioritization_fee_cache,
config.block_production_method.clone(),
config.enable_block_production_forwarding,
config.generator_config.clone(),
);

Expand Down
Loading

0 comments on commit 61d8be0

Please sign in to comment.