Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and Parquet (#4206)

Co-authored-by: alexandreyc <alexandre@crayssac.net>
alexandreyc authored May 12, 2023
1 parent 144528f commit 0190408
Showing 7 changed files with 67 additions and 26 deletions.
1 change: 1 addition & 0 deletions arrow-array/src/lib.rs
@@ -183,6 +183,7 @@ pub use array::*;
mod record_batch;
pub use record_batch::{
RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader,
RecordBatchWriter,
};

mod arithmetic;
6 changes: 6 additions & 0 deletions arrow-array/src/record_batch.rs
@@ -43,6 +43,12 @@ pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
}
}

/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
/// Write a single batch to the writer.
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;
}

/// A two-dimensional batch of column-oriented data with a defined
/// [schema](arrow_schema::Schema).
///
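With the trait in place, format-specific writers can be driven through one interface. A minimal sketch of generic usage (the write_all helper below is hypothetical, not part of this commit):

use arrow_array::{RecordBatch, RecordBatchWriter};
use arrow_schema::ArrowError;

// Hypothetical helper: write every batch through any RecordBatchWriter impl.
fn write_all<W: RecordBatchWriter>(
    writer: &mut W,
    batches: &[RecordBatch],
) -> Result<(), ArrowError> {
    for batch in batches {
        writer.write(batch)?;
    }
    Ok(())
}
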
6 changes: 6 additions & 0 deletions arrow-csv/src/writer.rs
@@ -193,6 +193,12 @@ impl<W: Write> Writer<W> {
}
}

impl<W: Write> RecordBatchWriter for Writer<W> {
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
self.write(batch)
}
}

/// A CSV writer builder
#[derive(Clone, Debug)]
pub struct WriterBuilder {
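The CSV Writer already exposes an inherent write method with the same signature, so the trait impl simply forwards to it. A sketch of writing a batch through the trait, assuming the usual arrow_csv::Writer constructor (illustrative, not taken from this commit):

use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, RecordBatchWriter};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let a = Int32Array::from(vec![1, 2, 3]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

    let mut writer = arrow_csv::Writer::new(Vec::new());
    // Explicit trait dispatch; writer.write(&batch) would resolve to the
    // inherent method and behave identically.
    RecordBatchWriter::write(&mut writer, &batch).unwrap();
}
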
12 changes: 12 additions & 0 deletions arrow-ipc/src/writer.rs
@@ -857,6 +857,12 @@ impl<W: Write> FileWriter<W> {
}
}

impl<W: Write> RecordBatchWriter for FileWriter<W> {
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
self.write(batch)
}
}

pub struct StreamWriter<W: Write> {
/// The object to write to
writer: BufWriter<W>,
@@ -991,6 +997,12 @@ impl<W: Write> StreamWriter<W> {
}
}

impl<W: Write> RecordBatchWriter for StreamWriter<W> {
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
self.write(batch)
}
}

/// Stores the encoded data, which is an crate::Message, and optional Arrow data
pub struct EncodedData {
/// An encoded crate::Message
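Since both IPC writers now implement the trait, and the trait is object-safe, a batch can be fanned out to several sinks behind dyn RecordBatchWriter. A sketch under the assumption that both writers target in-memory buffers (the fan_out helper is hypothetical):

use arrow_array::{RecordBatch, RecordBatchWriter};
use arrow_ipc::writer::{FileWriter, StreamWriter};
use arrow_schema::ArrowError;

// Hypothetical helper: write one batch to several writers via dynamic dispatch.
fn fan_out(
    writers: &mut [Box<dyn RecordBatchWriter>],
    batch: &RecordBatch,
) -> Result<(), ArrowError> {
    for writer in writers.iter_mut() {
        writer.write(batch)?;
    }
    Ok(())
}

fn example(batch: &RecordBatch) -> Result<(), ArrowError> {
    let mut writers: Vec<Box<dyn RecordBatchWriter>> = vec![
        Box::new(FileWriter::try_new(Vec::new(), batch.schema().as_ref())?),
        Box::new(StreamWriter::try_new(Vec::new(), batch.schema().as_ref())?),
    ];
    fan_out(&mut writers, batch)
}

Note that finalizing the output still goes through the concrete types (FileWriter::finish / StreamWriter::finish); the trait added here only covers write.
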
56 changes: 33 additions & 23 deletions arrow-json/src/writer.rs
@@ -35,7 +35,7 @@
//! let a = Int32Array::from(vec![1, 2, 3]);
//! let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();
//!
//! let json_rows = arrow_json::writer::record_batches_to_json_rows(&[batch]).unwrap();
//! let json_rows = arrow_json::writer::record_batches_to_json_rows(&[&batch]).unwrap();
//! assert_eq!(
//! serde_json::Value::Object(json_rows[1].clone()),
//! serde_json::json!({"a": 2}),
@@ -59,7 +59,7 @@
//! // Write the record batch out as JSON
//! let buf = Vec::new();
//! let mut writer = arrow_json::LineDelimitedWriter::new(buf);
//! writer.write_batches(&vec![batch]).unwrap();
//! writer.write_batches(&vec![&batch]).unwrap();
//! writer.finish().unwrap();
//!
//! // Get the underlying buffer back,
@@ -85,7 +85,7 @@
//! // Write the record batch out as a JSON array
//! let buf = Vec::new();
//! let mut writer = arrow_json::ArrayWriter::new(buf);
//! writer.write_batches(&vec![batch]).unwrap();
//! writer.write_batches(&vec![&batch]).unwrap();
//! writer.finish().unwrap();
//!
//! // Get the underlying buffer back,
@@ -390,7 +390,7 @@ fn set_column_for_json_rows(
/// Converts an arrow [`RecordBatch`] into a `Vec` of Serde JSON
/// [`JsonMap`]s (objects)
pub fn record_batches_to_json_rows(
batches: &[RecordBatch],
batches: &[&RecordBatch],
) -> Result<Vec<JsonMap<String, Value>>, ArrowError> {
let mut rows: Vec<JsonMap<String, Value>> = iter::repeat(JsonMap::new())
.take(batches.iter().map(|b| b.num_rows()).sum())
@@ -554,15 +554,15 @@ where
}

/// Convert the `RecordBatch` into JSON rows, and write them to the output
pub fn write(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
for row in record_batches_to_json_rows(&[batch])? {
self.write_row(&Value::Object(row))?;
}
Ok(())
}

/// Convert the [`RecordBatch`] into JSON rows, and write them to the output
pub fn write_batches(&mut self, batches: &[RecordBatch]) -> Result<(), ArrowError> {
pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
for row in record_batches_to_json_rows(batches)? {
self.write_row(&Value::Object(row))?;
}
@@ -586,6 +586,16 @@
}
}

impl<W, F> RecordBatchWriter for Writer<W, F>
where
W: Write,
F: JsonFormat,
{
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
self.write(batch)
}
}

#[cfg(test)]
mod tests {
use std::fs::{read_to_string, File};
@@ -631,7 +641,7 @@ mod tests {
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -662,7 +672,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -704,7 +714,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -759,7 +769,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -818,7 +828,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -864,7 +874,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -907,7 +917,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -950,7 +960,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1010,7 +1020,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1053,7 +1063,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1113,7 +1123,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1192,7 +1202,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1217,7 +1227,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

let result = String::from_utf8(buf).unwrap();
@@ -1315,7 +1325,7 @@ mod tests {
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

// NOTE: The last value should technically be {"list": [null]} but it appears
@@ -1378,7 +1388,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write_batches(&[batch]).unwrap();
writer.write_batches(&[&batch]).unwrap();
}

assert_json_eq(
@@ -1408,7 +1418,7 @@
let mut buf = Vec::new();
{
let mut writer = LineDelimitedWriter::new(&mut buf);
writer.write(batch).unwrap();
writer.write(&batch).unwrap();
}

let result = String::from_utf8(buf).unwrap();
@@ -1445,7 +1455,7 @@
let batch = reader.next().unwrap().unwrap();

// test batches = an empty batch + 2 same batches, finally result should be eq to 2 same batches
let batches = [RecordBatch::new_empty(schema), batch.clone(), batch];
let batches = [&RecordBatch::new_empty(schema), &batch, &batch];

let mut buf = Vec::new();
{
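Alongside the trait impl, the JSON writer's methods change from taking owned batches to references (write takes &RecordBatch, write_batches takes &[&RecordBatch]), so callers no longer have to move or clone batches into the writer. A rough migration sketch for code holding a Vec<RecordBatch> (illustrative, not part of the commit):

use arrow_array::RecordBatch;
use arrow_json::LineDelimitedWriter;
use arrow_schema::ArrowError;

fn to_ndjson(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> {
    let mut buf = Vec::new();
    {
        let mut writer = LineDelimitedWriter::new(&mut buf);
        // Collect references to satisfy the new &[&RecordBatch] signature.
        let refs: Vec<&RecordBatch> = batches.iter().collect();
        writer.write_batches(&refs)?;
        writer.finish()?;
    }
    Ok(buf)
}
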
2 changes: 1 addition & 1 deletion arrow/benches/json_reader.rs
@@ -92,7 +92,7 @@ fn large_bench_primitive(c: &mut Criterion) {
.unwrap();

let mut out = Vec::with_capacity(1024);
LineDelimitedWriter::new(&mut out).write(batch).unwrap();
LineDelimitedWriter::new(&mut out).write(&batch).unwrap();

let json = std::str::from_utf8(&out).unwrap();
do_bench(c, "large_bench_primitive", json, schema)
10 changes: 8 additions & 2 deletions parquet/src/arrow/arrow_writer/mod.rs
@@ -23,8 +23,8 @@ use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::{Decimal128Type, Int32Type, Int64Type, UInt32Type, UInt64Type};
use arrow_array::{types, Array, ArrayRef, RecordBatch};
use arrow_schema::{DataType as ArrowDataType, IntervalUnit, SchemaRef};
use arrow_array::{types, Array, ArrayRef, RecordBatch, RecordBatchWriter};
use arrow_schema::{ArrowError, DataType as ArrowDataType, IntervalUnit, SchemaRef};

use super::schema::{
add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema,
@@ -246,6 +246,12 @@ impl<W: Write> ArrowWriter<W> {
}
}

impl<W: Write> RecordBatchWriter for ArrowWriter<W> {
fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
self.write(batch).map_err(|e| e.into())
}
}

fn write_leaves<W: Write>(
row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
arrays: &[ArrayRef],
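For Parquet, the inherent ArrowWriter::write returns a parquet error, so the trait impl maps it into ArrowError with map_err(|e| e.into()). A minimal sketch of driving the Parquet writer through the trait (illustrative, not taken from this commit):

use std::sync::Arc;
use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchWriter};
use parquet::arrow::ArrowWriter;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let batch = RecordBatch::try_from_iter([
        ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef),
    ])?;

    // Write a single batch to an in-memory Parquet file.
    let mut writer = ArrowWriter::try_new(Vec::new(), batch.schema(), None)?;
    // Trait dispatch: any parquet error is converted into ArrowError by the impl.
    RecordBatchWriter::write(&mut writer, &batch)?;
    writer.close()?;
    Ok(())
}
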
