Skip to content

Commit

Permalink
feat(raiko): put the tasks that cannot run in parallel into pending l…
Browse files Browse the repository at this point in the history
…ist (#358)

* put the tasks that cannot run in parallel into pending list

Signed-off-by: smtmfft <smtm@taiko.xyz>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* fix merge conflicts

* fix compile issue

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

* Update host/src/proof.rs

Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>

---------

Signed-off-by: smtmfft <smtm@taiko.xyz>
Co-authored-by: Petar Vujović <petarvujovic98@gmail.com>
  • Loading branch information
smtmfft and petarvujovic98 authored Oct 14, 2024
1 parent eb4d032 commit ec483b7
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 22 deletions.
5 changes: 3 additions & 2 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ pub struct ProverState {
pub enum Message {
Cancel(TaskDescriptor),
Task(ProofRequest),
TaskComplete(ProofRequest),
CancelAggregate(AggregationOnlyRequest),
Aggregate(AggregationOnlyRequest),
}
Expand Down Expand Up @@ -200,9 +201,9 @@ impl ProverState {

let opts_clone = opts.clone();
let chain_specs_clone = chain_specs.clone();

let sender = task_channel.clone();
tokio::spawn(async move {
ProofActor::new(receiver, opts_clone, chain_specs_clone)
ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
.run()
.await;
});
Expand Down
89 changes: 69 additions & 20 deletions host/src/proof.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{collections::HashMap, str::FromStr, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
str::FromStr,
sync::Arc,
};

use anyhow::anyhow;
use raiko_core::{
Expand All @@ -16,10 +20,13 @@ use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrap
use reth_primitives::B256;
use tokio::{
select,
sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
sync::{
mpsc::{Receiver, Sender},
Mutex, OwnedSemaphorePermit, Semaphore,
},
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};

use crate::{
cache,
Expand All @@ -35,32 +42,42 @@ use crate::{
pub struct ProofActor {
opts: Opts,
chain_specs: SupportedChainSpecs,
tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
aggregate_tasks: Arc<Mutex<HashMap<AggregationOnlyRequest, CancellationToken>>>,
running_tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
receiver: Receiver<Message>,
sender: Sender<Message>,
}

impl ProofActor {
pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
let tasks = Arc::new(Mutex::new(
pub fn new(
sender: Sender<Message>,
receiver: Receiver<Message>,
opts: Opts,
chain_specs: SupportedChainSpecs,
) -> Self {
let running_tasks = Arc::new(Mutex::new(
HashMap::<TaskDescriptor, CancellationToken>::new(),
));
let aggregate_tasks = Arc::new(Mutex::new(HashMap::<
AggregationOnlyRequest,
CancellationToken,
>::new()));
let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));

Self {
tasks,
aggregate_tasks,
opts,
chain_specs,
aggregate_tasks,
running_tasks,
pending_tasks,
receiver,
sender,
}
}

pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
let tasks_map = self.tasks.lock().await;
let tasks_map = self.running_tasks.lock().await;
let Some(task) = tasks_map.get(&key) else {
warn!("No task with those keys to cancel");
return Ok(());
Expand All @@ -85,7 +102,7 @@ impl ProofActor {
Ok(())
}

pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
pub async fn run_task(&mut self, proof_request: ProofRequest) {
let cancel_token = CancellationToken::new();

let Ok((chain_id, blockhash)) = get_task_data(
Expand All @@ -106,10 +123,11 @@ impl ProofActor {
proof_request.prover.clone().to_string(),
));

let mut tasks = self.tasks.lock().await;
let mut tasks = self.running_tasks.lock().await;
tasks.insert(key.clone(), cancel_token.clone());
let sender = self.sender.clone();

let tasks = self.tasks.clone();
let tasks = self.running_tasks.clone();
let opts = self.opts.clone();
let chain_specs = self.chain_specs.clone();

Expand All @@ -118,7 +136,7 @@ impl ProofActor {
_ = cancel_token.cancelled() => {
info!("Task cancelled");
}
result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => {
match result {
Ok(status) => {
info!("Host handling message: {status:?}");
Expand All @@ -131,6 +149,11 @@ impl ProofActor {
}
let mut tasks = tasks.lock().await;
tasks.remove(&key);
// notify complete task to let next pending task run
sender
.send(Message::TaskComplete(proof_request))
.await
.expect("Couldn't send message");
});
}

Expand Down Expand Up @@ -203,21 +226,47 @@ impl ProofActor {
}

pub async fn run(&mut self) {
// recv() is protected by outside mpsc, no lock needed here
let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));

while let Some(message) = self.receiver.recv().await {
match message {
Message::Cancel(key) => {
debug!("Message::Cancel task: {key:?}");
if let Err(error) = self.cancel_task(key).await {
error!("Failed to cancel task: {error}")
}
}
Message::Task(proof_request) => {
let permit = Arc::clone(&semaphore)
.acquire_owned()
.await
.expect("Couldn't acquire permit");
self.run_task(proof_request, permit).await;
debug!("Message::Task proof_request: {proof_request:?}");
let running_task_count = self.running_tasks.lock().await.len();
if running_task_count < self.opts.concurrency_limit {
info!("Running task {proof_request:?}");
self.run_task(proof_request).await;
} else {
info!(
"Task concurrency limit reached, current running {running_task_count:?}, pending: {:?}",
self.pending_tasks.lock().await.len()
);
let mut pending_tasks = self.pending_tasks.lock().await;
pending_tasks.push_back(proof_request);
}
}
Message::TaskComplete(req) => {
// pop up pending task if any task complete
debug!("Message::TaskComplete: {req:?}");
info!(
"task completed, current running {:?}, pending: {:?}",
self.running_tasks.lock().await.len(),
self.pending_tasks.lock().await.len()
);
let mut pending_tasks = self.pending_tasks.lock().await;
if let Some(proof_request) = pending_tasks.pop_front() {
info!("Pop out pending task {proof_request:?}");
self.sender
.send(Message::Task(proof_request))
.await
.expect("Couldn't send message");
}
}
Message::CancelAggregate(request) => {
if let Err(error) = self.cancel_aggregation_task(request).await {
Expand Down Expand Up @@ -326,7 +375,7 @@ pub async fn handle_proof(
store: Option<&mut TaskManagerWrapper>,
) -> HostResult<Proof> {
info!(
"# Generating proof for block {} on {}",
"Generating proof for block {} on {}",
proof_request.block_number, proof_request.network
);

Expand Down

0 comments on commit ec483b7

Please sign in to comment.