diff --git a/parquet/src/arrow/async_writer/mod.rs b/parquet/src/arrow/async_writer/mod.rs
index dc000f248c9b..abfb1c54ed44 100644
--- a/parquet/src/arrow/async_writer/mod.rs
+++ b/parquet/src/arrow/async_writer/mod.rs
@@ -51,10 +51,7 @@
 //! # }
 //! ```
 
-use std::{
-    io::Write,
-    sync::{Arc, Mutex},
-};
+use std::{io::Write, sync::Arc};
 
 use crate::{
     arrow::ArrowWriter,
@@ -80,22 +77,24 @@ pub struct AsyncArrowWriter<W> {
 
     /// The inner buffer shared by the `sync_writer` and the `async_writer`
     shared_buffer: SharedBuffer,
-
-    /// The threshold triggering buffer flush
-    buffer_flush_threshold: usize,
 }
 
 impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
     /// Try to create a new Async Arrow Writer.
     ///
-    /// `buffer_flush_threshold` will be used to trigger flush of the inner buffer.
+    /// `buffer_size` determines the initial size of the intermediate buffer.
+    ///
+    /// The intermediate buffer will automatically be resized if necessary.
+    ///
+    /// [`Self::write`] will flush this intermediate buffer if it is at least
+    /// half full.
     pub fn try_new(
         writer: W,
         arrow_schema: SchemaRef,
-        buffer_flush_threshold: usize,
+        buffer_size: usize,
         props: Option<WriterProperties>,
     ) -> Result<Self> {
-        let shared_buffer = SharedBuffer::default();
+        let shared_buffer = SharedBuffer::new(buffer_size);
         let sync_writer = ArrowWriter::try_new(shared_buffer.clone(), arrow_schema, props)?;
 
@@ -103,22 +102,16 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
             sync_writer,
             async_writer: writer,
             shared_buffer,
-            buffer_flush_threshold,
         })
     }
 
     /// Enqueues the provided `RecordBatch` to be written
     ///
     /// After every sync write by the inner [ArrowWriter], the inner buffer will be
-    /// checked and flush if threshold is reached.
+    /// checked and flushed if at least half full.
    pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
         self.sync_writer.write(batch)?;
-        Self::try_flush(
-            &self.shared_buffer,
-            &mut self.async_writer,
-            self.buffer_flush_threshold,
-        )
-        .await
+        Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, false).await
     }
 
     /// Append [`KeyValue`] metadata in addition to those in [`WriterProperties`]
@@ -135,7 +128,7 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
         let metadata = self.sync_writer.close()?;
 
         // Force to flush the remaining data.
-        Self::try_flush(&self.shared_buffer, &mut self.async_writer, 0).await?;
+        Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, true).await?;
 
         Ok(metadata)
     }
@@ -143,24 +136,21 @@
     /// Flush the data in the [`SharedBuffer`] into the `async_writer` if its size
     /// exceeds the threshold.
     async fn try_flush(
-        shared_buffer: &SharedBuffer,
+        shared_buffer: &mut SharedBuffer,
         async_writer: &mut W,
-        threshold: usize,
+        force: bool,
     ) -> Result<()> {
-        let mut buffer = {
-            let mut buffer = shared_buffer.buffer.lock().unwrap();
-
-            if buffer.is_empty() || buffer.len() < threshold {
-                // no need to flush
-                return Ok(());
-            }
-            std::mem::take(&mut *buffer)
-        };
+        let mut buffer = shared_buffer.buffer.try_lock().unwrap();
+        if !force && buffer.len() < buffer.capacity() / 2 {
+            // no need to flush
+            return Ok(());
+        }
 
         async_writer
-            .write(&buffer)
+            .write(buffer.as_slice())
             .await
             .map_err(|e| ParquetError::External(Box::new(e)))?;
+
         async_writer
             .flush()
             .await
@@ -168,7 +158,6 @@
 
         // reuse the buffer.
         buffer.clear();
-        *shared_buffer.buffer.lock().unwrap() = buffer;
 
         Ok(())
     }
@@ -176,23 +165,31 @@
 
 /// A buffer with interior mutability shared by the [`ArrowWriter`] and
 /// [`AsyncArrowWriter`].
-#[derive(Clone, Default)]
+#[derive(Clone)]
 struct SharedBuffer {
     /// The inner buffer for reading and writing
     ///
     /// The lock is used to obtain internal mutability, so no worry about the
     /// lock contention.
-    buffer: Arc<Mutex<Vec<u8>>>,
+    buffer: Arc<futures::lock::Mutex<Vec<u8>>>,
+}
+
+impl SharedBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))),
+        }
+    }
 }
 
 impl Write for SharedBuffer {
     fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        let mut buffer = self.buffer.lock().unwrap();
+        let mut buffer = self.buffer.try_lock().unwrap();
         Write::write(&mut *buffer, buf)
     }
 
     fn flush(&mut self) -> std::io::Result<()> {
-        let mut buffer = self.buffer.lock().unwrap();
+        let mut buffer = self.buffer.try_lock().unwrap();
         Write::flush(&mut *buffer)
     }
 }
@@ -342,7 +339,7 @@ mod tests {
         };
 
         let test_buffer_flush_thresholds =
-            vec![0, 1024, 40 * 1024, 50 * 1024, 100 * 1024, usize::MAX];
+            vec![0, 1024, 40 * 1024, 50 * 1024, 100 * 1024];
 
         for buffer_flush_threshold in test_buffer_flush_thresholds {
             let reader = get_test_reader();
@@ -354,7 +351,7 @@
             let mut async_writer = AsyncArrowWriter::try_new(
                 &mut test_async_sink,
                 reader.schema(),
-                buffer_flush_threshold,
+                buffer_flush_threshold * 2,
                 Some(write_props.clone()),
             )
            .unwrap();
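For reference, a minimal usage sketch of the revised constructor, assuming the parquet crate's `async` feature, tokio with the `rt` and `macros` features, and `arrow_array` for building the batch; the in-memory `Vec<u8>` sink, the column values, and the 8 KiB `buffer_size` below are illustrative assumptions, not part of this patch.

```rust
// Hedged sketch (not part of the patch): exercises the new `buffer_size`
// argument of `AsyncArrowWriter::try_new`. Schema, values, and the 8 KiB
// capacity hint are assumptions made for illustration only.
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use parquet::arrow::AsyncArrowWriter;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
    let batch = RecordBatch::try_from_iter([("col", col)])?;

    // Any `AsyncWrite + Unpin + Send` sink works; tokio implements
    // `AsyncWrite` for `Vec<u8>`, which keeps the sketch self-contained.
    let mut sink = Vec::new();

    // `buffer_size` is now only an initial capacity for the shared
    // intermediate buffer; `write` flushes it to `sink` once it is at
    // least half full, and `close` forces a final flush.
    let mut writer = AsyncArrowWriter::try_new(&mut sink, batch.schema(), 8 * 1024, None)?;
    writer.write(&batch).await?;
    writer.close().await?;

    assert!(!sink.is_empty());
    Ok(())
}
```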