Skip to content

Commit

Permalink
Implement async collector, avoid blocking the stream task. (#1846)
Browse files Browse the repository at this point in the history
* [stream-task] Stream task implement async collector, avoid blocking the stream task.

* [sync] Refactor collect trait implement.
  • Loading branch information
jolestar authored Dec 20, 2020
1 parent 3b0e31b commit 395b159
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 86 deletions.
164 changes: 105 additions & 59 deletions commons/stream-task/src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

use crate::{TaskError, TaskEventHandle};
use anyhow::{Error, Result};
use async_std::task::JoinHandle;
use futures::channel::mpsc::{channel, Sender};
use futures::task::{Context, Poll};
use futures::Sink;
use futures::{Sink, StreamExt};
use log::debug;
use pin_project::pin_project;
use pin_utils::core_reexport::option::Option::Some;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
Expand All @@ -22,10 +26,7 @@ pub enum CollectorState {
pub trait TaskResultCollector<Item>: std::marker::Send + Unpin {
type Output: std::marker::Send;

fn collect(self: Pin<&mut Self>, item: Item) -> Result<CollectorState>;
fn flush(self: Pin<&mut Self>) -> Result<CollectorState> {
Ok(CollectorState::Need)
}
fn collect(&mut self, item: Item) -> Result<CollectorState>;
fn finish(self) -> Result<Self::Output>;
}

Expand All @@ -36,8 +37,8 @@ where
{
type Output = ();

fn collect(self: Pin<&mut Self>, item: Item) -> Result<CollectorState> {
self.get_mut()(item)?;
fn collect(&mut self, item: Item) -> Result<CollectorState> {
(self)(item)?;
Ok(CollectorState::Need)
}

Expand All @@ -52,8 +53,8 @@ where
{
type Output = Self;

fn collect(self: Pin<&mut Self>, item: Item) -> Result<CollectorState> {
self.get_mut().push(item);
fn collect(&mut self, item: Item) -> Result<CollectorState> {
self.push(item);
Ok(CollectorState::Need)
}

Expand All @@ -62,18 +63,35 @@ where
}
}

#[derive(Clone, Default)]
#[derive(Clone)]
pub struct CounterCollector {
counter: Arc<AtomicU64>,
max: u64,
}

impl Default for CounterCollector {
fn default() -> Self {
Self::new()
}
}

impl CounterCollector {
pub fn new() -> Self {
Self::default()
Self::new_with_counter(Arc::new(AtomicU64::default()))
}

pub fn new_with_counter(counter: Arc<AtomicU64>) -> Self {
Self { counter }
Self {
counter,
max: u64::max_value(),
}
}

pub fn new_with_max(max: u64) -> Self {
Self {
counter: Arc::new(AtomicU64::default()),
max,
}
}
}

Expand All @@ -83,9 +101,15 @@ where
{
type Output = u64;

fn collect(self: Pin<&mut Self>, _item: Item) -> Result<CollectorState, Error> {
fn collect(&mut self, _item: Item) -> Result<CollectorState, Error> {
self.counter.fetch_add(1, Ordering::SeqCst);
Ok(CollectorState::Need)
let count = self.counter.load(Ordering::SeqCst);
debug!("collect item, count: {}", count);
if count >= self.max {
Ok(CollectorState::Enough)
} else {
Ok(CollectorState::Need)
}
}

fn finish(self) -> Result<Self::Output> {
Expand All @@ -97,77 +121,99 @@ where
pub(crate) enum SinkError {
#[error("{0:?}")]
StreamTaskError(TaskError),
#[error("{0:?}")]
CollectorError(anyhow::Error),
#[error("Collector is enough.")]
CollectorEnough,
}

impl SinkError {
pub fn map_result(result: Result<(), SinkError>) -> Result<(), TaskError> {
match result {
Err(err) => match err {
SinkError::StreamTaskError(err) => Err(err),
//SinkError::CollectorError(err) => Err(TaskError::CollectorError(err)),
SinkError::CollectorEnough => Ok(()),
},
Ok(()) => Ok(()),
}
}
}

#[pin_project]
pub(crate) struct FutureTaskSink<C> {
pub(crate) struct FutureTaskSink<Item, Output> {
#[pin]
sender: Sender<Item>,
#[pin]
collector: C,
event_handle: Arc<dyn TaskEventHandle>,
task_handle: JoinHandle<Result<Output, TaskError>>,
}

impl<C> FutureTaskSink<C> {
pub fn new<Item>(collector: C, event_handle: Arc<dyn TaskEventHandle>) -> Self
impl<Item, Output> FutureTaskSink<Item, Output> {
pub fn new<C>(
mut collector: C,
buffer_size: usize,
event_handle: Arc<dyn TaskEventHandle>,
) -> Self
where
C: TaskResultCollector<Item>,
Item: Send + 'static,
Output: Send + 'static,
C: TaskResultCollector<Item, Output = Output> + 'static,
{
let (sender, receiver) = channel(buffer_size);
let task_handle = async_std::task::spawn(async move {
let mut receiver = receiver.fuse();
while let Some(item) = receiver.next().await {
event_handle.on_item();
let collector_state = collector.collect(item).map_err(TaskError::CollectorError)?;
match collector_state {
CollectorState::Enough => break,
CollectorState::Need => {
//continue
}
}
}
collector.finish().map_err(TaskError::CollectorError)
});
Self {
collector,
event_handle,
sender,
task_handle,
}
}

pub fn into_collector(self) -> C {
self.collector
}

fn flush_inner<Item>(self: Pin<&mut Self>) -> Poll<Result<(), SinkError>>
where
C: TaskResultCollector<Item>,
{
let this = self.project();
match this.collector.flush() {
Err(e) => Poll::Ready(Err(SinkError::CollectorError(e))),
Ok(state) => match state {
CollectorState::Need => Poll::Ready(Ok(())),
CollectorState::Enough => Poll::Ready(Err(SinkError::CollectorEnough)),
},
}
pub async fn wait_output(self) -> Result<Output, TaskError> {
self.task_handle.await
}
}

impl<C, Item> Sink<Item> for FutureTaskSink<C>
where
C: TaskResultCollector<Item>,
{
impl<Item, Output> Sink<Item> for FutureTaskSink<Item, Output> {
type Error = SinkError;

fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.project();
//if the sender is disconnect, means the task is finished, so map error to CollectorEnough, and close the sink.
this.sender
.poll_ready(cx)
.map_err(|_| SinkError::CollectorEnough)
}

fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
let this = self.project();
this.event_handle.on_item();
let collector_state = this
.collector
.collect(item)
.map_err(SinkError::CollectorError)?;
match collector_state {
CollectorState::Enough => Err(SinkError::CollectorEnough),
CollectorState::Need => Ok(()),
}
//ignore sender error, because if send error, may bean task is finished
this.sender
.start_send(item)
.map_err(|_| SinkError::CollectorEnough)
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.flush_inner()
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.as_mut().project();
this.sender
.poll_flush(cx)
.map_err(|_| SinkError::CollectorEnough)
}

fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.flush_inner()
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
debug!("FutureTaskSink poll_close");
let this = self.as_mut().project();
this.sender
.poll_close(cx)
.map_err(|_| SinkError::CollectorEnough)
}
}
19 changes: 6 additions & 13 deletions commons/stream-task/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,13 @@ where
})
.flatten()
.map_err(SinkError::StreamTaskError);
let mut sink = FutureTaskSink::new(self.collector, event_handle.clone());
let sink_result = sink.send_all(&mut buffered_stream).await;
if let Err(sink_err) = sink_result {
match sink_err {
SinkError::StreamTaskError(e) => return Err(e),
SinkError::CollectorError(e) => return Err(TaskError::CollectorError(e)),
SinkError::CollectorEnough => {
//continue
}
}
}
let collector = sink.into_collector();
let mut sink =
FutureTaskSink::new(self.collector, self.buffer_size, event_handle.clone());
SinkError::map_result(sink.send_all(&mut buffered_stream).await)?;
SinkError::map_result(sink.close().await)?;
let output = sink.wait_output().await?;
event_handle.on_finish(task_name.to_string());
collector.finish().map_err(TaskError::CollectorError)
Ok(output)
}
.boxed();

Expand Down
24 changes: 23 additions & 1 deletion commons/stream-task/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,33 @@ mod tests {
assert!(result.is_err());
let task_err = result.err().unwrap();
assert!(task_err.is_break_error());
assert_eq!(break_at, counter.load(Ordering::SeqCst));
assert_eq!(break_at, counter.load(Ordering::SeqCst) + 1);

let report = event_handle.get_reports().pop().unwrap();
debug!("{}", report);
assert!(report.processed_items > 0);
assert!(report.processed_items < max);
}

#[stest::test]
async fn test_collect_enough() {
let max = 100;
let collector_max = 50;
let config = MockTestConfig::new_with_max(max);
let mock_state = MockTaskState::new(config);

let event_handle = Arc::new(TaskEventCounterHandle::new());
let result = TaskGenerator::new(
mock_state.clone(),
10,
0,
0,
CounterCollector::new_with_max(collector_max),
event_handle,
)
.generate()
.await;
//assert!(result.is_ok());
assert_eq!(result.unwrap(), collector_max);
}
}
9 changes: 2 additions & 7 deletions sync/src/tasks/accumulator_sync_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use starcoin_accumulator::accumulator_info::AccumulatorInfo;
use starcoin_accumulator::{Accumulator, AccumulatorTreeStore, MerkleAccumulator};
use starcoin_crypto::HashValue;
use starcoin_types::block::{BlockIdAndNumber, BlockNumber};
use std::pin::Pin;
use std::sync::Arc;
use stream_task::{CollectorState, TaskResultCollector, TaskState};

Expand Down Expand Up @@ -105,20 +104,16 @@ impl AccumulatorCollector {
impl TaskResultCollector<HashValue> for AccumulatorCollector {
type Output = (BlockIdAndNumber, MerkleAccumulator);

fn collect(self: Pin<&mut Self>, item: HashValue) -> Result<CollectorState> {
fn collect(&mut self, item: HashValue) -> Result<CollectorState> {
self.accumulator.append(&[item])?;
self.accumulator.flush()?;
if self.accumulator.num_leaves() == self.target.num_leaves {
Ok(CollectorState::Enough)
} else {
Ok(CollectorState::Need)
}
}

fn flush(self: Pin<&mut Self>) -> Result<CollectorState> {
self.accumulator.flush()?;
Ok(CollectorState::Need)
}

fn finish(self) -> Result<Self::Output> {
let info = self.accumulator.get_info();
ensure!(
Expand Down
3 changes: 1 addition & 2 deletions sync/src/tasks/block_sync_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use starcoin_types::block::{Block, BlockInfo, BlockNumber};
use starcoin_types::peer_info::PeerId;
use starcoin_vm_types::on_chain_config::GlobalTimeOnChain;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use stream_task::{CollectorState, TaskResultCollector, TaskState};

Expand Down Expand Up @@ -263,7 +262,7 @@ where
{
type Output = BlockChain;

fn collect(mut self: Pin<&mut Self>, item: SyncBlockData) -> Result<CollectorState> {
fn collect(&mut self, item: SyncBlockData) -> Result<CollectorState> {
let (block, block_info, peer_id) = item.into();
let block_id = block.id();
let timestamp = block.header().timestamp;
Expand Down
Loading

0 comments on commit 395b159

Please sign in to comment.