Skip to content

Commit

Permalink
refactor(core): Migrate RangeWrite to ConcurrentTasks (#4696)
Browse files Browse the repository at this point in the history
Signed-off-by: Xuanwo <github@xuanwo.io>
  • Loading branch information
Xuanwo authored Jun 5, 2024
1 parent 63ed080 commit 2034a26
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 105 deletions.
203 changes: 99 additions & 104 deletions core/src/raw/oio/write/range_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use futures::Future;
use futures::FutureExt;
use futures::StreamExt;
use futures::{select, Future};

use crate::raw::*;
use crate::*;
Expand Down Expand Up @@ -89,74 +85,75 @@ pub trait RangeWrite: Send + Sync + Unpin + 'static {
fn abort_range(&self, location: &str) -> impl Future<Output = Result<()>> + MaybeSend;
}

/// WritePartResult is the result returned by [`WriteRangeFuture`].
///
/// The error part will carries input `(offset, bytes, err)` so caller can retry them.
type WriteRangeResult = std::result::Result<(), (u64, Buffer, Error)>;

struct WriteRangeFuture(BoxedStaticFuture<WriteRangeResult>);

/// # Safety
///
/// wasm32 is a special target that we only have one event-loop for this WriteRangeFuture.
unsafe impl Send for WriteRangeFuture {}

/// # Safety
///
/// We will only take `&mut Self` reference for WriteRangeFuture.
unsafe impl Sync for WriteRangeFuture {}

impl Future for WriteRangeFuture {
type Output = WriteRangeResult;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().0.poll_unpin(cx)
}
}

impl WriteRangeFuture {
pub fn new<W: RangeWrite>(
w: Arc<W>,
location: Arc<String>,
offset: u64,
bytes: Buffer,
) -> Self {
let fut = async move {
w.write_range(&location, offset, bytes.len() as u64, bytes.clone())
.await
.map_err(|err| (offset, bytes, err))
};
struct WriteInput<W: RangeWrite> {
w: Arc<W>,
executor: Executor,

WriteRangeFuture(Box::pin(fut))
}
location: Arc<String>,
offset: u64,
bytes: Buffer,
}

/// RangeWriter will implements [`oio::Write`] based on range write.
pub struct RangeWriter<W: RangeWrite> {
w: Arc<W>,
executor: Executor,

location: Option<Arc<String>>,
next_offset: u64,
buffer: Option<Buffer>,
futures: ConcurrentFutures<WriteRangeFuture>,

w: Arc<W>,
cache: Option<Buffer>,
tasks: ConcurrentTasks<WriteInput<W>, ()>,
}

impl<W: RangeWrite> RangeWriter<W> {
/// Create a new MultipartWriter.
pub fn new(inner: W, concurrent: usize) -> Self {
pub fn new(inner: W, executor: Option<Executor>, concurrent: usize) -> Self {
let executor = executor.unwrap_or_default();

Self {
w: Arc::new(inner),

futures: ConcurrentFutures::new(1.max(concurrent)),
buffer: None,
executor: executor.clone(),
location: None,
next_offset: 0,
cache: None,

tasks: ConcurrentTasks::new(executor, concurrent, |input| {
Box::pin(async move {
let fut = input.w.write_range(
&input.location,
input.offset,
input.bytes.len() as u64,
input.bytes.clone(),
);
match input.executor.timeout() {
None => {
let result = fut.await;
(input, result)
}
Some(timeout) => {
let result = select! {
result = fut.fuse() => {
result
}
_ = timeout.fuse() => {
Err(Error::new(
ErrorKind::Unexpected, "write range timeout")
.with_context("offset", input.offset.to_string())
.set_temporary())
}
};
(input, result)
}
}
})
}),
}
}

fn fill_cache(&mut self, bs: Buffer) -> usize {
let size = bs.len();
assert!(self.buffer.is_none());
self.buffer = Some(bs);
assert!(self.cache.is_none());
self.cache = Some(bs);
size
}
}
Expand All @@ -167,7 +164,7 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
Some(location) => location,
None => {
// Fill cache with the first write.
if self.buffer.is_none() {
if self.cache.is_none() {
let size = self.fill_cache(bs);
return Ok(size);
}
Expand All @@ -179,64 +176,46 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
}
};

loop {
if self.futures.has_remaining() {
let cache = self.buffer.take().expect("cache must be valid");
let offset = self.next_offset;
self.next_offset += cache.len() as u64;
self.futures.push_back(WriteRangeFuture::new(
self.w.clone(),
location,
offset,
cache,
));

let size = self.fill_cache(bs);
return Ok(size);
}

if let Some(Err((offset, bytes, err))) = self.futures.next().await {
self.futures.push_front(WriteRangeFuture::new(
self.w.clone(),
location,
offset,
bytes,
));
return Err(err);
}
}
let bytes = self.cache.clone().expect("pending write must exist");
let length = bytes.len() as u64;
let offset = self.next_offset;

self.tasks
.execute(WriteInput {
w: self.w.clone(),
executor: self.executor.clone(),
location,
offset,
bytes,
})
.await?;
self.cache = None;
self.next_offset += length;
let size = self.fill_cache(bs);
Ok(size)
}

async fn close(&mut self) -> Result<()> {
let Some(location) = self.location.clone() else {
let (size, body) = match self.buffer.clone() {
let (size, body) = match self.cache.clone() {
Some(cache) => (cache.len(), cache),
None => (0, Buffer::new()),
};
// Call write_once if there is no data in buffer and no location.
return self.w.write_once(size as u64, body).await;
self.w.write_once(size as u64, body).await?;
self.cache = None;
return Ok(());
};

if !self.futures.is_empty() {
while let Some(result) = self.futures.next().await {
if let Err((offset, bytes, err)) = result {
self.futures.push_front(WriteRangeFuture::new(
self.w.clone(),
location,
offset,
bytes,
));
return Err(err);
};
}
}
// Make sure all tasks are finished.
while self.tasks.next().await.transpose()?.is_some() {}

if let Some(buffer) = self.buffer.clone() {
if let Some(buffer) = self.cache.clone() {
let offset = self.next_offset;
self.w
.complete_range(&location, offset, buffer.len() as u64, buffer)
.await?;
self.buffer = None;
self.cache = None;
}

Ok(())
Expand All @@ -247,10 +226,9 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
return Ok(());
};

self.futures.clear();
self.tasks.clear();
self.cache = None;
self.w.abort_range(&location).await?;
// Clean cache when abort_range returns success.
self.buffer = None;
Ok(())
}
}
Expand All @@ -259,11 +237,13 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
mod tests {
use std::collections::HashSet;
use std::sync::Mutex;
use std::time::Duration;

use pretty_assertions::assert_eq;
use rand::thread_rng;
use rand::Rng;
use rand::RngCore;
use tokio::time::sleep;

use super::*;
use crate::raw::oio::Write;
Expand Down Expand Up @@ -298,9 +278,14 @@ mod tests {
}

async fn write_range(&self, _: &str, offset: u64, size: u64, _: Buffer) -> Result<()> {
// We will have 50% percent rate for write part to fail.
if thread_rng().gen_bool(5.0 / 10.0) {
return Err(Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!"));
// Add an async sleep here to enforce some pending.
sleep(Duration::from_millis(50)).await;

// We will have 10% percent rate for write part to fail.
if thread_rng().gen_bool(1.0 / 10.0) {
return Err(
Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
);
}

let mut test = self.lock().unwrap();
Expand All @@ -318,6 +303,16 @@ mod tests {
}

async fn complete_range(&self, _: &str, offset: u64, size: u64, _: Buffer) -> Result<()> {
// Add an async sleep here to enforce some pending.
sleep(Duration::from_millis(50)).await;

// We will have 10% percent rate for write part to fail.
if thread_rng().gen_bool(1.0 / 10.0) {
return Err(
Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
);
}

let mut test = self.lock().unwrap();
test.length += size;

Expand All @@ -340,7 +335,7 @@ mod tests {
async fn test_range_writer_with_concurrent_errors() {
let mut rng = thread_rng();

let mut w = RangeWriter::new(TestWrite::new(), 8);
let mut w = RangeWriter::new(TestWrite::new(), Some(Executor::new()), 200);
let mut total_size = 0u64;

for _ in 0..1000 {
Expand Down
3 changes: 2 additions & 1 deletion core/src/services/gcs/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,9 @@ impl Access for GcsBackend {

async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
let concurrent = args.concurrent();
let executor = args.executor().cloned();
let w = GcsWriter::new(self.core.clone(), path, args);
let w = oio::RangeWriter::new(w, concurrent);
let w = oio::RangeWriter::new(w, executor, concurrent);

Ok((RpWrite::default(), w))
}
Expand Down

0 comments on commit 2034a26

Please sign in to comment.