Skip to content

Commit

Permalink
[FRAME] Make MQ pallet re-entrancy safe (paritytech#2356)
Browse files Browse the repository at this point in the history
Closes paritytech#2319

Changes:
- Ensure that only `enqueue_message(s)` is callable from within the
message processor. This prevents the storage corruption that can currently
occur when the pallet is called into recursively.
- Use `H256` instead of `[u8; 32]` for clearer API.

## Details

The re-entrancy check is done with the `environmental` crate by adding a
`with_service_mutex(f)` function that runs the closure exclusively. This
works because the MQ pallet is not instantiable.

---------

Signed-off-by: Oliver Tale-Yazdi <oliver.tale-yazdi@parity.io>
Co-authored-by: Francisco Aguirre <franciscoaguirreperez@gmail.com>
  • Loading branch information
ggwpez and franciscoaguirre authored Dec 7, 2023
1 parent 110e140 commit 0c119a9
Show file tree
Hide file tree
Showing 8 changed files with 472 additions and 67 deletions.
2 changes: 2 additions & 0 deletions substrate/frame/message-queue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ codec = { package = "parity-scale-codec", version = "3.6.1", default-features =
scale-info = { version = "2.10.0", default-features = false, features = ["derive"] }
serde = { version = "1.0.193", optional = true, features = ["derive"] }
log = { version = "0.4.17", default-features = false }
environmental = { version = "1.1.4", default-features = false }

sp-core = { path = "../../primitives/core", default-features = false }
sp-io = { path = "../../primitives/io", default-features = false }
Expand All @@ -34,6 +35,7 @@ rand_distr = "0.4.3"
default = ["std"]
std = [
"codec/std",
"environmental/std",
"frame-benchmarking?/std",
"frame-support/std",
"frame-system/std",
Expand Down
7 changes: 4 additions & 3 deletions substrate/frame/message-queue/src/benchmarking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use super::{mock_helpers::*, Pallet as MessageQueue, *};
use frame_benchmarking::v2::*;
use frame_support::traits::Get;
use frame_system::RawOrigin;
use sp_io::hashing::blake2_256;
use sp_std::prelude::*;

#[benchmarks(
Expand Down Expand Up @@ -142,7 +143,7 @@ mod benchmarks {
// Check that it was processed.
assert_last_event::<T>(
Event::Processed {
id: sp_io::hashing::blake2_256(&msg),
id: blake2_256(&msg).into(),
origin: 0.into(),
weight_used: 1.into_weight(),
success: true,
Expand Down Expand Up @@ -227,7 +228,7 @@ mod benchmarks {

assert_last_event::<T>(
Event::Processed {
id: sp_io::hashing::blake2_256(&((msgs - 1) as u32).encode()),
id: blake2_256(&((msgs - 1) as u32).encode()).into(),
origin: 0.into(),
weight_used: Weight::from_parts(1, 1),
success: true,
Expand Down Expand Up @@ -264,7 +265,7 @@ mod benchmarks {

assert_last_event::<T>(
Event::Processed {
id: sp_io::hashing::blake2_256(&((msgs - 1) as u32).encode()),
id: blake2_256(&((msgs - 1) as u32).encode()).into(),
origin: 0.into(),
weight_used: Weight::from_parts(1, 1),
success: true,
Expand Down
82 changes: 76 additions & 6 deletions substrate/frame/message-queue/src/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

use crate::{
mock::{
build_and_execute, CountingMessageProcessor, IntoWeight, MockedWeightInfo,
NumMessagesProcessed, YieldingQueues,
build_and_execute, gen_seed, Callback, CountingMessageProcessor, IntoWeight,
MessagesProcessed, MockedWeightInfo, NumMessagesProcessed, YieldingQueues,
},
mock_helpers::MessageOrigin,
*,
Expand Down Expand Up @@ -120,13 +120,13 @@ impl Config for Test {
/// Processing all remaining 28639 messages
/// ```
#[test]
#[ignore] // Only run in the CI.
#[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_enqueue_and_service() {
let blocks = 20;
let max_queues = 10_000;
let max_messages_per_queue = 10_000;
let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43);
let mut rng = StdRng::seed_from_u64(gen_seed());

build_and_execute::<Test>(|| {
let mut msgs_remaining = 0;
Expand All @@ -148,6 +148,74 @@ fn stress_test_enqueue_and_service() {
});
}

/// Stress-tests enqueueing from *inside* message processing.
///
/// Very similar to `stress_test_enqueue_and_service`, but the bulk of the messages is enqueued
/// by the processor callback while the queue is being serviced.
#[test]
#[ignore] // Only run in the CI, otherwise it's too slow.
fn stress_test_recursive() {
    let blocks = 20;
    let mut rng = StdRng::seed_from_u64(gen_seed());

    // The callback cannot capture anything, hence all shared state lives in thread-locals:
    // - `TotalEnqueued`: running balance of messages that still have to be processed.
    // - `Enqueued`: messages added by the callback since the counter was last taken.
    // - `Called`: number of times the callback ran.
    parameter_types! {
        pub static TotalEnqueued: u32 = 0;
        pub static Enqueued: u32 = 0;
        pub static Called: u32 = 0;
    }

    /// Services a randomly sized, non-empty share of the `pending` messages.
    ///
    /// The share is deducted from `pending` and returned to the caller.
    fn service_random_share(rng: &mut StdRng, pending: &mut u32) -> u32 {
        // Pick a fraction of all messages currently in queue and process them.
        let share = rng.gen_range(1..=*pending);
        log::info!("Processing {} of all messages {}", share, *pending);
        process_some_messages(share); // This also advances the block.
        *pending -= share;
        share
    }

    // Reset all counters to their defaults before starting.
    TotalEnqueued::take();
    Enqueued::take();
    Called::take();

    Callback::set(Box::new(|_, _| {
        // The RNG seed is derived from the current value of the `Enqueued` counter.
        let mut batch_rng = StdRng::seed_from_u64(Enqueued::get() as u64);
        let (max_queues, max_messages_per_queue) = (1_000, 1_000);
        let max_msg_len = MaxMessageLenOf::<Test>::get();

        // Instead of directly enqueueing, we enqueue inside a `service` call.
        let added =
            enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut batch_rng);
        Called::set(Called::get() + 1);
        Enqueued::set(Enqueued::get() + added);
        TotalEnqueued::set(TotalEnqueued::get() + added);
    }));

    build_and_execute::<Test>(|| {
        let mut pending = 0;

        // Phase 1: enqueue one callback message per block and service a random share.
        for b in 0..blocks {
            log::info!("Block #{}", b);
            let payload = format!("callback={b}");
            MessageQueue::enqueue_message(
                BoundedSlice::defensive_truncate_from(payload.as_bytes()),
                b.into(),
            );

            // The `+ 1` accounts for the callback message that was just enqueued.
            pending += Enqueued::take() + 1;
            let share = service_random_share(&mut rng, &mut pending);
            TotalEnqueued::set(TotalEnqueued::get() - share + 1);
            MessageQueue::do_try_state().unwrap();
        }

        // Phase 2: keep servicing until the callback ran once per block.
        while Called::get() < blocks {
            pending += Enqueued::take();
            let share = service_random_share(&mut rng, &mut pending);
            TotalEnqueued::set(TotalEnqueued::get() - share);
            MessageQueue::do_try_state().unwrap();
        }

        // Phase 3: drain whatever is left and verify the final state.
        let leftover = TotalEnqueued::take();
        log::info!("Processing all remaining {} messages", leftover);
        process_all_messages(leftover);
        assert_eq!(Called::get(), blocks);
        post_conditions();
    });
}

/// Simulates heavy usage of the suspension logic via `Yield`.
///
/// # Example output
Expand All @@ -164,14 +232,14 @@ fn stress_test_enqueue_and_service() {
/// Processing all remaining 430 messages
/// ```
#[test]
#[ignore] // Only run in the CI.
#[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_queue_suspension() {
let blocks = 20;
let max_queues = 10_000;
let max_messages_per_queue = 10_000;
let (max_suspend_per_block, max_resume_per_block) = (100, 50);
let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43);
let mut rng = StdRng::seed_from_u64(gen_seed());

build_and_execute::<Test>(|| {
let mut suspended = BTreeSet::<u32>::new();
Expand Down Expand Up @@ -300,6 +368,7 @@ fn process_all_messages(expected: u32) {

assert_eq!(consumed, Weight::from_all(expected as u64));
assert_eq!(NumMessagesProcessed::take(), expected as usize);
MessagesProcessed::take();
}

/// Returns the weight consumed by `MessageQueue::on_initialize()`.
Expand Down Expand Up @@ -327,5 +396,6 @@ fn post_conditions() {
assert!(ServiceHead::<Test>::get().is_none());
// This still works fine.
assert_eq!(MessageQueue::service_queues(Weight::MAX), Weight::zero(), "Nothing left");
MessageQueue::do_try_state().unwrap();
next_block();
}
Loading

0 comments on commit 0c119a9

Please sign in to comment.