diff --git a/rust/Cargo.toml b/rust/Cargo.toml index a92c0d743e..c093811c52 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -16,7 +16,10 @@ edition = "2021" arrow = { version = "39", optional = true } arrow-array = { version = "39", optional = true } arrow-cast = { version = "39", optional = true } +arrow-ord = { version = "39", optional = true } +arrow-row = { version = "39", optional = true } arrow-schema = { version = "39", optional = true } +arrow-select = { version = "39", optional = true } async-trait = "0.1" bytes = "1" chrono = { version = "0.4.22", default-features = false, features = ["clock"] } @@ -102,7 +105,7 @@ glibc_version = { path = "../glibc_version", version = "0.1" } [features] azure = ["object_store/azure"] -arrow = ["dep:arrow", "arrow-array", "arrow-cast", "arrow-schema"] +arrow = ["dep:arrow", "arrow-array", "arrow-cast", "arrow-ord", "arrow-row", "arrow-schema", "arrow-select"] default = ["arrow", "parquet"] datafusion = [ "dep:datafusion", diff --git a/rust/src/writer/record_batch.rs b/rust/src/writer/record_batch.rs index 6111d06d28..fc9421599c 100644 --- a/rust/src/writer/record_batch.rs +++ b/rust/src/writer/record_batch.rs @@ -26,31 +26,30 @@ //! })) //! } //! ``` -use super::{ - stats::{create_add, NullCounts}, - utils::{ - arrow_schema_without_partitions, next_data_path, record_batch_without_partitions, - stringified_partition_value, PartitionPath, - }, - DeltaWriter, DeltaWriterError, -}; -use crate::builder::DeltaTableBuilder; -use crate::writer::stats::apply_null_counts; -use crate::writer::utils::ShareableBuffer; -use crate::DeltaTableError; -use crate::{action::Add, storage::DeltaObjectStore, DeltaTable, DeltaTableMetaData, Schema}; -use arrow::array::{Array, UInt32Array}; -use arrow::compute::{lexicographical_partition_ranges, lexsort_to_indices, take, SortColumn}; -use arrow::datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; + +use std::collections::HashMap; +use std::convert::TryFrom; +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +use arrow_ord::{partition::lexicographical_partition_ranges, sort::SortColumn}; +use arrow_row::{RowConverter, SortField}; +use arrow_schema::{ArrowError, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use arrow_select::take::take; use bytes::Bytes; use object_store::ObjectStore; use parquet::{arrow::ArrowWriter, errors::ParquetError}; use parquet::{basic::Compression, file::properties::WriterProperties}; -use std::collections::HashMap; -use std::convert::TryFrom; -use std::sync::Arc; + +use super::stats::{create_add, NullCounts}; +use super::utils::{ + arrow_schema_without_partitions, next_data_path, record_batch_without_partitions, + stringified_partition_value, PartitionPath, +}; +use super::{DeltaTableError, DeltaWriter, DeltaWriterError}; +use crate::builder::DeltaTableBuilder; +use crate::writer::{stats::apply_null_counts, utils::ShareableBuffer}; +use crate::{action::Add, storage::DeltaObjectStore, DeltaTable, DeltaTableMetaData, Schema}; /// Writes messages to a delta lake table. pub struct RecordBatchWriter { @@ -354,24 +353,18 @@ pub(crate) fn divide_by_partition_values( let schema = values.schema(); - // collect all columns in order relevant for partitioning - let sort_columns = partition_columns - .clone() - .into_iter() - .map(|col| { - Ok(SortColumn { - values: values.column(schema.index_of(&col)?).clone(), - options: None, - }) - }) + let projection = partition_columns + .iter() + .map(|n| Ok(schema.index_of(n)?)) .collect::, DeltaWriterError>>()?; + let sort_columns = values.project(&projection)?; - let indices = lexsort_to_indices(sort_columns.as_slice(), None)?; - let sorted_partition_columns = sort_columns + let indices = lexsort_to_indices(sort_columns.columns()); + let sorted_partition_columns = partition_columns .iter() .map(|c| { Ok(SortColumn { - values: take(c.values.as_ref(), &indices, None)?, + values: take(values.column(schema.index_of(c)?), &indices, None)?, options: None, }) }) @@ -410,6 +403,18 @@ pub(crate) fn divide_by_partition_values( Ok(partitions) } +fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array { + let fields = arrays + .iter() + .map(|a| SortField::new(a.data_type().clone())) + .collect(); + let mut converter = RowConverter::new(fields).unwrap(); + let rows = converter.convert_columns(arrays).unwrap(); + let mut sort: Vec<_> = rows.iter().enumerate().collect(); + sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32)) +} + #[cfg(test)] mod tests { use super::*; @@ -506,7 +511,6 @@ mod tests { let mut writer = RecordBatchWriter::for_table(&table).unwrap(); let partitions = writer.divide_by_partition_values(&batch).unwrap(); - println!("partitions: {:?}", partitions); let expected_keys = vec![ String::from("modified=2021-02-01"),