This repository has been archived by the owner on Feb 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 223
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added
async
writer of the Arrow stream.
- Loading branch information
1 parent
83d828c
commit 03c20e0
Showing
10 changed files
with
232 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
use futures::AsyncWrite; | ||
use futures::AsyncWriteExt; | ||
|
||
use crate::error::{ArrowError, Result}; | ||
|
||
use super::super::CONTINUATION_MARKER; | ||
use super::common::pad_to_8; | ||
use super::common::EncodedData; | ||
|
||
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written | ||
pub async fn write_message<W: AsyncWrite + Unpin + Send>( | ||
writer: &mut W, | ||
encoded: EncodedData, | ||
) -> Result<(usize, usize)> { | ||
let arrow_data_len = encoded.arrow_data.len(); | ||
if arrow_data_len % 8 != 0 { | ||
return Err(ArrowError::Ipc("Arrow data not aligned".to_string())); | ||
} | ||
|
||
let a = 8 - 1; | ||
let buffer = encoded.ipc_message; | ||
let flatbuf_size = buffer.len(); | ||
let prefix_size = 8; | ||
let aligned_size = (flatbuf_size + prefix_size + a) & !a; | ||
let padding_bytes = aligned_size - flatbuf_size - prefix_size; | ||
|
||
write_continuation(writer, (aligned_size - prefix_size) as i32).await?; | ||
|
||
// write the flatbuf | ||
if flatbuf_size > 0 { | ||
writer.write_all(&buffer).await?; | ||
} | ||
// write padding | ||
writer.write_all(&vec![0; padding_bytes]).await?; | ||
|
||
// write arrow data | ||
let body_len = if arrow_data_len > 0 { | ||
write_body_buffers(writer, &encoded.arrow_data).await? | ||
} else { | ||
0 | ||
}; | ||
|
||
Ok((aligned_size, body_len)) | ||
} | ||
|
||
/// Write a record batch to the writer, writing the message size before the message | ||
/// if the record batch is being written to a stream | ||
pub async fn write_continuation<W: AsyncWrite + Unpin + Send>( | ||
writer: &mut W, | ||
total_len: i32, | ||
) -> Result<usize> { | ||
writer.write_all(&CONTINUATION_MARKER).await?; | ||
writer.write_all(&total_len.to_le_bytes()[..]).await?; | ||
writer.flush().await?; | ||
Ok(8) | ||
} | ||
|
||
async fn write_body_buffers<W: AsyncWrite + Unpin + Send>( | ||
mut writer: W, | ||
data: &[u8], | ||
) -> Result<usize> { | ||
let len = data.len(); | ||
let pad_len = pad_to_8(data.len()); | ||
let total_len = len + pad_len; | ||
|
||
// write body buffer | ||
writer.write_all(data).await?; | ||
if pad_len > 0 { | ||
writer.write_all(&vec![0u8; pad_len][..]).await?; | ||
} | ||
|
||
writer.flush().await?; | ||
Ok(total_len) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
use std::io::Write; | ||
|
||
use crate::error::{ArrowError, Result}; | ||
|
||
use super::super::CONTINUATION_MARKER; | ||
use super::common::pad_to_8; | ||
use super::common::EncodedData; | ||
|
||
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written | ||
pub fn write_message<W: Write>(writer: &mut W, encoded: EncodedData) -> Result<(usize, usize)> { | ||
let arrow_data_len = encoded.arrow_data.len(); | ||
if arrow_data_len % 8 != 0 { | ||
return Err(ArrowError::Ipc("Arrow data not aligned".to_string())); | ||
} | ||
|
||
let a = 8 - 1; | ||
let buffer = encoded.ipc_message; | ||
let flatbuf_size = buffer.len(); | ||
let prefix_size = 8; | ||
let aligned_size = (flatbuf_size + prefix_size + a) & !a; | ||
let padding_bytes = aligned_size - flatbuf_size - prefix_size; | ||
|
||
write_continuation(writer, (aligned_size - prefix_size) as i32)?; | ||
|
||
// write the flatbuf | ||
if flatbuf_size > 0 { | ||
writer.write_all(&buffer)?; | ||
} | ||
// write padding | ||
writer.write_all(&vec![0; padding_bytes])?; | ||
|
||
// write arrow data | ||
let body_len = if arrow_data_len > 0 { | ||
write_body_buffers(writer, &encoded.arrow_data)? | ||
} else { | ||
0 | ||
}; | ||
|
||
Ok((aligned_size, body_len)) | ||
} | ||
|
||
fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> Result<usize> { | ||
let len = data.len(); | ||
let pad_len = pad_to_8(data.len()); | ||
let total_len = len + pad_len; | ||
|
||
// write body buffer | ||
writer.write_all(data)?; | ||
if pad_len > 0 { | ||
writer.write_all(&vec![0u8; pad_len][..])?; | ||
} | ||
|
||
writer.flush()?; | ||
Ok(total_len) | ||
} | ||
|
||
/// Write a record batch to the writer, writing the message size before the message | ||
/// if the record batch is being written to a stream | ||
pub fn write_continuation<W: Write>(writer: &mut W, total_len: i32) -> Result<usize> { | ||
writer.write_all(&CONTINUATION_MARKER)?; | ||
writer.write_all(&total_len.to_le_bytes()[..])?; | ||
writer.flush()?; | ||
Ok(8) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
//! `async` writing of arrow streams | ||
use futures::AsyncWrite; | ||
|
||
use super::common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}; | ||
use super::common_async::{write_continuation, write_message}; | ||
use super::schema_to_bytes; | ||
|
||
use crate::datatypes::*; | ||
use crate::error::{ArrowError, Result}; | ||
use crate::record_batch::RecordBatch; | ||
|
||
/// An `async` writer to the Apache Arrow stream format. | ||
pub struct StreamWriter<W: AsyncWrite + Unpin + Send> { | ||
/// The object to write to | ||
writer: W, | ||
/// IPC write options | ||
write_options: WriteOptions, | ||
/// Whether the writer footer has been written, and the writer is finished | ||
finished: bool, | ||
/// Keeps track of dictionaries that have been written | ||
dictionary_tracker: DictionaryTracker, | ||
} | ||
|
||
impl<W: AsyncWrite + Unpin + Send> StreamWriter<W> { | ||
/// Creates a new [`StreamWriter`] | ||
pub fn new(writer: W, write_options: WriteOptions) -> Self { | ||
Self { | ||
writer, | ||
write_options, | ||
finished: false, | ||
dictionary_tracker: DictionaryTracker::new(false), | ||
} | ||
} | ||
|
||
/// Starts the stream | ||
pub async fn start(&mut self, schema: &Schema) -> Result<()> { | ||
let encoded_message = EncodedData { | ||
ipc_message: schema_to_bytes(schema), | ||
arrow_data: vec![], | ||
}; | ||
write_message(&mut self.writer, encoded_message).await?; | ||
Ok(()) | ||
} | ||
|
||
/// Writes a [`RecordBatch`] to the stream | ||
pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> { | ||
if self.finished { | ||
return Err(ArrowError::Ipc( | ||
"Cannot write record batch to stream writer as it is closed".to_string(), | ||
)); | ||
} | ||
|
||
// todo: move this out of the `async` since this is blocking. | ||
let (encoded_dictionaries, encoded_message) = | ||
encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)?; | ||
|
||
for encoded_dictionary in encoded_dictionaries { | ||
write_message(&mut self.writer, encoded_dictionary).await?; | ||
} | ||
|
||
write_message(&mut self.writer, encoded_message).await?; | ||
Ok(()) | ||
} | ||
|
||
/// Finishes the stream | ||
pub async fn finish(&mut self) -> Result<()> { | ||
write_continuation(&mut self.writer, 0).await?; | ||
self.finished = true; | ||
Ok(()) | ||
} | ||
|
||
/// Consumes itself, returning the inner writer. | ||
pub fn into_inner(self) -> W { | ||
self.writer | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters