Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Amortized intermediate allocations in IPC writer #1362

Merged
4 commits merged on Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/io/flight/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub fn serialize_schema_to_info(
};

let mut schema = vec![];
write::common_sync::write_message(&mut schema, encoded_data)?;
write::common_sync::write_message(&mut schema, &encoded_data)?;
Ok(schema)
}

Expand Down
1 change: 1 addition & 0 deletions src/io/ipc/append/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl<R: Read + Seek + Write> FileWriter<R> {
dictionaries,
cannot_replace: true,
},
encoded_message: Default::default(),
})
}
}
4 changes: 2 additions & 2 deletions src/io/ipc/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ pub fn compress_zstd(input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
}

#[cfg(not(feature = "io_ipc_compression"))]
pub fn compress_lz4(_input_buf: &[u8], _output_buf: &mut Vec<u8>) -> Result<()> {
pub fn compress_lz4(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> {
use crate::error::Error;
Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string()))
}

#[cfg(not(feature = "io_ipc_compression"))]
pub fn compress_zstd(_input_buf: &[u8], _output_buf: &mut Vec<u8>) -> Result<()> {
pub fn compress_zstd(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> {
use crate::error::Error;
Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string()))
}
Expand Down
42 changes: 32 additions & 10 deletions src/io/ipc/write/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ pub fn encode_chunk(
dictionary_tracker: &mut DictionaryTracker,
options: &WriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
let mut encoded_message = EncodedData::default();
let encoded_dictionaries = encode_chunk_amortized(
chunk,
fields,
dictionary_tracker,
options,
&mut encoded_message,
)?;
Ok((encoded_dictionaries, encoded_message))
}

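/// Same as [`encode_chunk`], but writes the encoded record batch into the caller-provided
/// `encoded_message` so that its buffers can be reused across calls instead of being
/// reallocated. Returns the encoded dictionaries.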
pub fn encode_chunk_amortized(
chunk: &Chunk<Box<dyn Array>>,
fields: &[IpcField],
dictionary_tracker: &mut DictionaryTracker,
options: &WriteOptions,
encoded_message: &mut EncodedData,
) -> Result<Vec<EncodedData>> {
let mut encoded_dictionaries = vec![];

for (field, array) in fields.iter().zip(chunk.as_ref()) {
Expand All @@ -189,9 +208,9 @@ pub fn encode_chunk(
)?;
}

let encoded_message = chunk_to_bytes(chunk, options);
chunk_to_bytes_amortized(chunk, options, encoded_message);

Ok((encoded_dictionaries, encoded_message))
Ok(encoded_dictionaries)
}

fn serialize_compression(
Expand All @@ -213,10 +232,16 @@ fn serialize_compression(

/// Write [`Chunk`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
fn chunk_to_bytes(chunk: &Chunk<Box<dyn Array>>, options: &WriteOptions) -> EncodedData {
fn chunk_to_bytes_amortized(
chunk: &Chunk<Box<dyn Array>>,
options: &WriteOptions,
encoded_message: &mut EncodedData,
) {
let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
let mut arrow_data: Vec<u8> = vec![];
let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
arrow_data.clear();

let mut offset = 0;
for array in chunk.arrays() {
write(
Expand Down Expand Up @@ -248,11 +273,8 @@ fn chunk_to_bytes(chunk: &Chunk<Box<dyn Array>>, options: &WriteOptions) -> Enco

let mut builder = Builder::new();
let ipc_message = builder.finish(&message, None);

EncodedData {
ipc_message: ipc_message.to_vec(),
arrow_data,
}
encoded_message.ipc_message = ipc_message.to_vec();
encoded_message.arrow_data = arrow_data
}

/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
Expand Down Expand Up @@ -360,7 +382,7 @@ impl DictionaryTracker {
}

/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct EncodedData {
/// An encoded ipc::Schema::Message
pub ipc_message: Vec<u8>,
Expand Down
10 changes: 6 additions & 4 deletions src/io/ipc/write/common_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use super::common::pad_to_64;
use super::common::EncodedData;

/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
pub fn write_message<W: Write>(writer: &mut W, encoded: EncodedData) -> Result<(usize, usize)> {
pub fn write_message<W: Write>(writer: &mut W, encoded: &EncodedData) -> Result<(usize, usize)> {
let arrow_data_len = encoded.arrow_data.len();

let a = 8 - 1;
let buffer = encoded.ipc_message;
let buffer = &encoded.ipc_message;
let flatbuf_size = buffer.len();
let prefix_size = 8;
let aligned_size = (flatbuf_size + prefix_size + a) & !a;
Expand All @@ -21,10 +21,12 @@ pub fn write_message<W: Write>(writer: &mut W, encoded: EncodedData) -> Result<(

// write the flatbuf
if flatbuf_size > 0 {
writer.write_all(&buffer)?;
writer.write_all(buffer)?;
}
// write padding
writer.write_all(&vec![0; padding_bytes])?;
// the message is aligned to an 8-byte boundary, so an 8-byte zero buffer always covers the padding
const PADDING_MAX: [u8; 8] = [0u8; 8];
writer.write_all(&PADDING_MAX[..padding_bytes])?;

// write arrow data
let body_len = if arrow_data_len > 0 {
Expand Down
6 changes: 3 additions & 3 deletions src/io/ipc/write/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<W: Write> StreamWriter<W> {
ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()),
arrow_data: vec![],
};
write_message(&mut self.writer, encoded_message)?;
write_message(&mut self.writer, &encoded_message)?;
Ok(())
}

Expand Down Expand Up @@ -91,10 +91,10 @@ impl<W: Write> StreamWriter<W> {
)?;

for encoded_dictionary in encoded_dictionaries {
write_message(&mut self.writer, encoded_dictionary)?;
write_message(&mut self.writer, &encoded_dictionary)?;
}

write_message(&mut self.writer, encoded_message)?;
write_message(&mut self.writer, &encoded_message)?;
Ok(())
}

Expand Down
27 changes: 21 additions & 6 deletions src/io/ipc/write/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow_format::ipc::planus::Builder;
use super::{
super::IpcField,
super::ARROW_MAGIC,
common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions},
common::{DictionaryTracker, EncodedData, WriteOptions},
common_sync::{write_continuation, write_message},
default_ipc_fields, schema, schema_to_bytes,
};
Expand All @@ -14,6 +14,7 @@ use crate::array::Array;
use crate::chunk::Chunk;
use crate::datatypes::*;
use crate::error::{Error, Result};
use crate::io::ipc::write::common::encode_chunk_amortized;

#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
Expand Down Expand Up @@ -41,6 +42,8 @@ pub struct FileWriter<W: Write> {
pub(crate) state: State,
/// Keeps track of dictionaries that have been written
pub(crate) dictionary_tracker: DictionaryTracker,
/// Buffer/scratch that is reused between writes
pub(crate) encoded_message: EncodedData,
}

impl<W: Write> FileWriter<W> {
Expand Down Expand Up @@ -83,6 +86,7 @@ impl<W: Write> FileWriter<W> {
dictionaries: Default::default(),
cannot_replace: true,
},
encoded_message: Default::default(),
}
}

Expand All @@ -91,6 +95,17 @@ impl<W: Write> FileWriter<W> {
self.writer
}

/// Takes the inner memory scratches out of this writer, leaving empty (default) ones in
/// their place, so that they can be reused in a new writer.
/// This can be used to save memory allocations for performance reasons.
pub fn get_scratches(&mut self) -> EncodedData {
std::mem::take(&mut self.encoded_message)
}
/// Sets the inner memory scratches of this writer, replacing the current ones, so that
/// previously allocated buffers (e.g. obtained via `get_scratches`) can be reused.
/// This can be used to save memory allocations for performance reasons.
pub fn set_scratches(&mut self, scratches: EncodedData) {
self.encoded_message = scratches;
}

/// Writes the header and first (schema) message to the file.
/// # Errors
/// Errors if the file has been started or has finished.
Expand All @@ -109,7 +124,7 @@ impl<W: Write> FileWriter<W> {
arrow_data: vec![],
};

let (meta, data) = write_message(&mut self.writer, encoded_message)?;
let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment
self.state = State::Started;
Ok(())
Expand All @@ -132,17 +147,17 @@ impl<W: Write> FileWriter<W> {
} else {
self.ipc_fields.as_ref()
};

let (encoded_dictionaries, encoded_message) = encode_chunk(
let encoded_dictionaries = encode_chunk_amortized(
chunk,
ipc_fields,
&mut self.dictionary_tracker,
&self.options,
&mut self.encoded_message,
)?;

// add all dictionaries
for encoded_dictionary in encoded_dictionaries {
let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
let (meta, data) = write_message(&mut self.writer, &encoded_dictionary)?;

let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
Expand All @@ -153,7 +168,7 @@ impl<W: Write> FileWriter<W> {
self.block_offsets += meta + data;
}

let (meta, data) = write_message(&mut self.writer, encoded_message)?;
let (meta, data) = write_message(&mut self.writer, &self.encoded_message)?;
// add a record block for the footer
let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
Expand Down
3 changes: 1 addition & 2 deletions src/temporal_conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ fn chrono_tz_utf_to_timestamp_ns<O: Offset>(
timezone: String,
) -> Result<PrimitiveArray<i64>> {
Err(Error::InvalidArgumentError(format!(
"timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
timezone
"timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)",
)))
}

Expand Down