Skip to content

Commit

Permalink
Add ArrayWriter indirection (#1764) (#2091)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Jul 21, 2022
1 parent be0d34d commit 576069a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 35 deletions.
65 changes: 45 additions & 20 deletions parquet/src/arrow/arrow_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use super::schema::{
decimal_length_from_precision,
};

use crate::column::writer::ColumnWriter;
use crate::column::writer::{get_column_writer, ColumnWriter};
use crate::errors::{ParquetError, Result};
use crate::file::metadata::RowGroupMetaDataPtr;
use crate::file::properties::WriterProperties;
Expand All @@ -43,6 +43,44 @@ use levels::{calculate_array_levels, LevelInfo};

mod levels;

/// An object-safe API for writing an [`ArrayRef`]
trait ArrayWriter {
fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>;

fn close(&mut self) -> Result<()>;
}

/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`]
struct ColumnArrayWriter<'a>(Option<SerializedColumnWriter<'a>>);

impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?;
Ok(())
}

fn close(&mut self) -> Result<()> {
self.0.take().unwrap().close()
}
}

fn get_writer<'a, W: Write>(
row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
) -> Result<Box<dyn ArrayWriter + 'a>> {
let array_writer = row_group_writer
.next_column_with_factory(|descr, props, page_writer, on_close| {
// TODO: Special case array readers (#1764)

let column_writer = get_column_writer(descr, props.clone(), page_writer);
let serialized_writer =
SerializedColumnWriter::new(column_writer, Some(on_close));

Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
})?
.expect("Unable to get column writer");
Ok(array_writer)
}

/// Arrow writer
///
/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
Expand Down Expand Up @@ -229,17 +267,6 @@ impl<W: Write> ArrowWriter<W> {
}
}

/// Convenience method to get the next ColumnWriter from the RowGroupWriter
#[inline]
fn get_col_writer<'a, W: Write>(
row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
) -> Result<SerializedColumnWriter<'a>> {
let col_writer = row_group_writer
.next_column()?
.expect("Unable to get column writer");
Ok(col_writer)
}

fn write_leaves<W: Write>(
row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
arrays: &[ArrayRef],
Expand Down Expand Up @@ -277,15 +304,14 @@ fn write_leaves<W: Write>(
| ArrowDataType::LargeUtf8
| ArrowDataType::Decimal(_, _)
| ArrowDataType::FixedSizeBinary(_) => {
let mut col_writer = get_col_writer(row_group_writer)?;
let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
write_leaf(
col_writer.untyped(),
writer.write(
array,
levels.pop().expect("Levels exhausted"),
)?;
}
col_writer.close()?;
writer.close()?;
Ok(())
}
ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
Expand Down Expand Up @@ -338,17 +364,16 @@ fn write_leaves<W: Write>(
Ok(())
}
ArrowDataType::Dictionary(_, value_type) => {
let mut col_writer = get_col_writer(row_group_writer)?;
let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
// cast dictionary to a primitive
let array = arrow::compute::cast(array, value_type)?;
write_leaf(
col_writer.untyped(),
writer.write(
&array,
levels.pop().expect("Levels exhausted"),
)?;
}
col_writer.close()?;
writer.close()?;
Ok(())
}
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
Expand Down
51 changes: 36 additions & 15 deletions parquet/src/file/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ use crate::file::{
metadata::*, properties::WriterPropertiesPtr,
statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC,
};
use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr};
use crate::schema::types::{
self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr,
};
use crate::util::io::TryClone;

/// A wrapper around a [`Write`] that keeps track of the number
Expand Down Expand Up @@ -367,22 +369,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
}
}

/// Returns the next column writer, if available; otherwise returns `None`.
/// In case of any IO error or Thrift error, or if row group writer has already been
/// closed returns `Err`.
pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
/// Returns the next column writer, if available, using the factory function;
/// otherwise returns `None`.
pub(crate) fn next_column_with_factory<'b, F, C>(
&'b mut self,
factory: F,
) -> Result<Option<C>>
where
F: FnOnce(
ColumnDescPtr,
&'b WriterPropertiesPtr,
Box<dyn PageWriter + 'b>,
OnCloseColumnChunk<'b>,
) -> Result<C>,
{
self.assert_previous_writer_closed()?;

if self.column_index >= self.descr.num_columns() {
return Ok(None);
}
let page_writer = Box::new(SerializedPageWriter::new(self.buf));
let column_writer = get_column_writer(
self.descr.column(self.column_index),
self.props.clone(),
page_writer,
);
self.column_index += 1;

let total_bytes_written = &mut self.total_bytes_written;
let total_rows_written = &mut self.total_rows_written;
Expand Down Expand Up @@ -413,10 +419,25 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
Ok(())
};

Ok(Some(SerializedColumnWriter::new(
column_writer,
Some(Box::new(on_close)),
)))
let column = self.descr.column(self.column_index);
self.column_index += 1;

Ok(Some(factory(
column,
&self.props,
page_writer,
Box::new(on_close),
)?))
}

/// Returns the next column writer, if available; otherwise returns `None`.
/// In case of any IO error or Thrift error, or if row group writer has already been
/// closed returns `Err`.
pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
self.next_column_with_factory(|descr, props, page_writer, on_close| {
let column_writer = get_column_writer(descr, props.clone(), page_writer);
Ok(SerializedColumnWriter::new(column_writer, Some(on_close)))
})
}

/// Closes this row group writer and returns row group metadata.
Expand Down

0 comments on commit 576069a

Please sign in to comment.