From 805ea9249829f9ee259ff97b63e68f4cd747d6e9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Apr 2021 16:38:51 -0400 Subject: [PATCH 1/2] ARROW-12411: [Rust] Add Builder interface for adding Arrays to record batches --- rust/arrow/src/datatypes/schema.rs | 6 +++ rust/arrow/src/record_batch.rs | 78 +++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/rust/arrow/src/datatypes/schema.rs b/rust/arrow/src/datatypes/schema.rs index ad89b29cacda6..35ee3353d62f6 100644 --- a/rust/arrow/src/datatypes/schema.rs +++ b/rust/arrow/src/datatypes/schema.rs @@ -152,6 +152,12 @@ impl Schema { }) } + /// Appends `field` to the end of this `Schema`'s list of + /// fields. + pub fn push(&mut self, field: Field) { + self.fields.push(field) + } + /// Returns an immutable reference of the vector of `Field` instances. #[inline] pub const fn fields(&self) -> &Vec { diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index 93abb909d020f..822ac1d6d70d1 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -93,7 +93,7 @@ impl RecordBatch { Ok(RecordBatch { schema, columns }) } - /// Creates a new empty [`RecordBatch`]. + /// Creates a new empty [`RecordBatch`] based on `schema`. pub fn new_empty(schema: SchemaRef) -> Self { let columns = schema .fields() @@ -103,6 +103,56 @@ impl RecordBatch { RecordBatch { schema, columns } } + /// Creates a new [`RecordBatch`] with no columns + /// + /// TODO: add a code example using `append` + pub fn new() -> Self { + Self { + schema: Arc::new(Schema::empty()), + columns: Vec::new(), + } + } + + /// Appends the `field_values` array to this `RecordBatch` as a + /// field named `field_name`. + /// + /// TODO: code example + /// + /// TODO: on error, can we return `Self` in some meaningful way? 
+ pub fn append(self, field_name: &str, field_values: ArrayRef) -> Result { + if let Some(col) = self.columns.get(0) { + if col.len() != field_values.len() { + return Err(ArrowError::InvalidArgumentError( + format!("all columns in a record batch must have the same length. expected {}, field {} had {} ", + col.len(), field_name, field_values.len()) + )); + } + } + + let Self { + schema, + mut columns, + } = self; + + // modify the schema we have if possible, otherwise copy + let mut schema = match Arc::try_unwrap(schema) { + Ok(schema) => schema, + Err(shared_schema) => shared_schema.as_ref().clone(), + }; + + let nullable = field_values.null_count() > 0; + schema.push(Field::new( + field_name, + field_values.data_type().clone(), + nullable, + )); + let schema = Arc::new(schema); + + columns.push(field_values); + + Ok(Self { schema, columns }) + } + /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error /// if any validation check fails. fn validate_new_batch( @@ -337,6 +387,32 @@ mod tests { assert_eq!(5, record_batch.column(1).data().len()); } + #[test] + fn create_record_batch_builder() { + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4), Some(5)])); + let b = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + + let record_batch = RecordBatch::new() + .append("a", a) + .unwrap() + .append("b", b) + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, false), + ]); + + assert_eq!(record_batch.schema().as_ref(), &expected_schema); + + assert_eq!(5, record_batch.num_rows()); + assert_eq!(2, record_batch.num_columns()); + assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); + assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type()); + assert_eq!(5, record_batch.column(0).data().len()); + assert_eq!(5, record_batch.column(1).data().len()); + } + #[test] fn create_record_batch_schema_mismatch() { 
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); From 5438428016314b3f53265241e428c870425f09bc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Apr 2021 16:48:38 -0400 Subject: [PATCH 2/2] appease clippy --- rust/arrow/src/record_batch.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index 822ac1d6d70d1..1d1a583ca710e 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -295,6 +295,12 @@ impl RecordBatch { } } +impl Default for RecordBatch { + fn default() -> Self { + Self::new() + } +} + /// Options that control the behaviour used when creating a [`RecordBatch`]. #[derive(Debug)] pub struct RecordBatchOptions { @@ -389,7 +395,13 @@ mod tests { #[test] fn create_record_batch_builder() { - let a = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4), Some(5)])); + let a = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); let b = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); let record_batch = RecordBatch::new()