diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index 0abff85937..b93c27f6cb 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -443,3 +443,37 @@ def __stringify_partition_values(
                 str_value = str(value)
             out.append((field, op, str_value))
         return out
+
+    def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch:
+        """
+        Return a PyArrow RecordBatch containing all current add actions.
+
+        Add actions represent the files that currently make up the table. This
+        data is a low-level representation parsed from the transaction log.
+
+        :param flatten: whether to flatten the schema. Partition values columns are
+            given the prefix `partition.`, statistics (null_count, min, and max) are
+            given the prefixes `null_count.`, `min.`, and `max.`, and tags are given
+            the prefix `tags.`. Nested field names are concatenated with `.`.
+
+        :returns: a PyArrow RecordBatch containing the add action data.
+
+        Examples:
+
+        >>> from deltalake import DeltaTable, write_deltalake
+        >>> import pyarrow as pa
+        >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
+        >>> write_deltalake("tmp", data, partition_by=["x"])
+        >>> dt = DeltaTable("tmp")
+        >>> dt.get_add_actions().to_pandas()
+                                                        path  size_bytes       modification_time  data_change partition_values  num_records null_count       min       max
+        0  x=2/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True         {'x': 2}            1   {'y': 0}  {'y': 5}  {'y': 5}
+        1  x=3/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True         {'x': 3}            1   {'y': 0}  {'y': 6}  {'y': 6}
+        2  x=1/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True         {'x': 1}            1   {'y': 0}  {'y': 4}  {'y': 4}
+        >>> dt.get_add_actions(flatten=True).to_pandas()
+                                                        path  size_bytes       modification_time  data_change  partition.x  num_records  null_count.y  min.y  max.y
+        0  x=2/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True            2            1             0      5      5
+        1  x=3/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True            3            1             0      6      6
+        2  x=1/0-91820cbf-f698-45fb-886d-5d5f5669530b-0.p...         565 1970-01-20 08:40:08.071         True            1            1             0      4      4
+        """
+        return self._table.get_add_actions(flatten)
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 5664436050..6fc786be05 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -453,6 +453,15 @@ impl RawDeltaTable {
 
         Ok(())
     }
+
+    pub fn get_add_actions(&self, flatten: bool) -> PyResult<PyArrowType<RecordBatch>> {
+        Ok(PyArrowType(
+            self._table
+                .get_state()
+                .add_actions_table(flatten)
+                .map_err(PyDeltaTableError::from_raw)?,
+        ))
+    }
 }
 
 fn convert_partition_filters<'a>(
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index bbeca64a7d..164a32c008 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -269,6 +269,45 @@ def test_history_partitioned_table_metadata():
     }
 
 
+@pytest.mark.parametrize("flatten", [True, False])
+def test_add_actions_table(flatten: bool):
+    table_path = "../rust/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    actions_df = dt.get_add_actions(flatten)
+    # RecordBatch doesn't have a sort_by method yet
+    actions_df = pa.Table.from_batches([actions_df]).sort_by("path").to_batches()[0]
+
+    assert actions_df.num_rows == 6
+    assert actions_df["path"] == pa.array(
+        [
+            "year=2020/month=1/day=1/part-00000-8eafa330-3be9-4a39-ad78-fd13c2027c7e.c000.snappy.parquet",
+            "year=2020/month=2/day=3/part-00000-94d16827-f2fd-42cd-a060-f67ccc63ced9.c000.snappy.parquet",
+            "year=2020/month=2/day=5/part-00000-89cdd4c8-2af7-4add-8ea3-3990b2f027b5.c000.snappy.parquet",
+            "year=2021/month=12/day=20/part-00000-9275fdf4-3961-4184-baa0-1c8a2bb98104.c000.snappy.parquet",
+            "year=2021/month=12/day=4/part-00000-6dc763c0-3e8b-4d52-b19e-1f92af3fbb25.c000.snappy.parquet",
+            "year=2021/month=4/day=5/part-00000-c5856301-3439-4032-a6fc-22b7bc92bebb.c000.snappy.parquet",
+        ]
+    )
+    assert actions_df["size_bytes"] == pa.array([414, 414, 414, 407, 414, 414])
+    assert actions_df["data_change"] == pa.array([True] * 6)
+    assert actions_df["modification_time"] == pa.array(
+        [1615555646000] * 6, type=pa.timestamp("ms")
+    )
+
+    if flatten:
+        partition_year = actions_df["partition.year"]
+        partition_month = actions_df["partition.month"]
+        partition_day = actions_df["partition.day"]
+    else:
+        partition_year = actions_df["partition_values"].field("year")
+        partition_month = actions_df["partition_values"].field("month")
+        partition_day = actions_df["partition_values"].field("day")
+
+    assert partition_year == pa.array(["2020"] * 3 + ["2021"] * 3)
+    assert partition_month == pa.array(["1", "2", "2", "12", "12", "4"])
+    assert partition_day == pa.array(["1", "3", "5", "20", "4", "5"])
+
+
 def assert_correct_files(dt: DeltaTable, partition_filters, expected_paths):
     assert dt.files(partition_filters) == expected_paths
     absolute_paths = [os.path.join(dt.table_uri, path) for path in expected_paths]
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index f04bd04e4d..7f823bb96c 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -17,6 +17,7 @@ chrono = { version = "0.4.22", default-features = false, features = ["clock"] }
 cfg-if = "1"
 errno = "0.2"
 futures = "0.3"
+itertools = "0.10"
 lazy_static = "1"
 log = "0"
 libc = ">=0.2.90, <1"
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 280ce33710..d90692470f 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -70,6 +70,7 @@
 
 #![deny(warnings)]
 #![deny(missing_docs)]
+#![allow(rustdoc::invalid_html_tags)]
 
 #[cfg(all(feature = "parquet", feature = "parquet2"))]
 compile_error!(
@@ -92,6 +93,9 @@ pub mod table_properties;
 pub mod table_state;
 pub mod time_utils;
+#[cfg(all(feature = "arrow"))]
+pub mod table_state_arrow;
+
 #[cfg(all(feature = "arrow", feature = "parquet"))]
 pub mod checkpoints;
 #[cfg(all(feature = "arrow", feature = "parquet"))]
diff --git a/rust/src/table_state.rs b/rust/src/table_state.rs
index 373747fe49..9c8a3da089 100644
--- a/rust/src/table_state.rs
+++ b/rust/src/table_state.rs
@@ -384,7 +384,6 @@ impl DeltaTableState {
 mod tests {
     use super::*;
     use pretty_assertions::assert_eq;
-    use std::collections::HashMap;
 
     #[test]
     fn state_round_trip() {
diff --git a/rust/src/table_state_arrow.rs b/rust/src/table_state_arrow.rs
new file mode 100644
index 0000000000..21a3f9a7d2
--- /dev/null
+++ b/rust/src/table_state_arrow.rs
@@ -0,0 +1,587 @@
+//! Methods to get Delta Table state in Arrow structures
+//!
+//! See [crate::table_state::DeltaTableState].
+
+use crate::action::{ColumnCountStat, ColumnValueStat, Stats};
+use crate::table_state::DeltaTableState;
+use crate::DeltaDataTypeLong;
+use crate::DeltaTableError;
+use crate::SchemaDataType;
+use crate::SchemaTypeStruct;
+use arrow::array::{
+    ArrayRef, BinaryArray, BooleanArray, Date32Array, Float64Array, Int64Array, StringArray,
+    StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+};
+use arrow::compute::cast;
+use arrow::compute::kernels::cast_utils::Parser;
+use arrow::datatypes::{DataType, Date32Type, Field, TimeUnit, TimestampMicrosecondType};
+use itertools::Itertools;
+use std::borrow::Cow;
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::sync::Arc;
+
+impl DeltaTableState {
+    /// Get an [arrow::record_batch::RecordBatch] containing add action data.
+    ///
+    /// # Arguments
+    ///
+    /// * `flatten` - whether to flatten the schema. Partition values columns are
+    ///   given the prefix `partition.`, statistics (null_count, min, and max) are
+    ///   given the prefixes `null_count.`, `min.`, and `max.`, and tags are given
+    ///   the prefix `tags.`. Nested field names are concatenated with `.`.
+    ///
+    /// # Data schema
+    ///
+    /// Each row represents a file that is a part of the selected table's state.
+    ///
+    /// * `path` (String): relative or absolute path to a file.
+    /// * `size_bytes` (Int64): size of file in bytes.
+    /// * `modification_time` (Millisecond Timestamp): time the file was created.
+    /// * `data_change` (Boolean): false if the data represents data moved from other
+    ///   files in the same transaction.
+    /// * `partition.{partition column name}` (matches column type): value of the
+    ///   partition the file corresponds to.
+    /// * `null_count.{col_name}` (Int64): number of null values for the column in
+    ///   this file.
+    /// * `min.{col_name}` (matches column type): minimum value of the column in the
+    ///   file (if available).
+    /// * `max.{col_name}` (matches column type): maximum value of the column in the
+    ///   file (if available).
+    /// * `tags.{tag_key}` (String): value of a metadata tag for the file.
+ pub fn add_actions_table( + &self, + flatten: bool, + ) -> Result { + let mut paths = arrow::array::StringBuilder::with_capacity( + self.files().len(), + self.files().iter().map(|add| add.path.len()).sum(), + ); + for action in self.files() { + paths.append_value(&action.path); + } + + let size = self + .files() + .iter() + .map(|file| file.size) + .collect::(); + let mod_time: TimestampMillisecondArray = self + .files() + .iter() + .map(|file| file.modification_time) + .collect::>() + .into(); + let data_change = self + .files() + .iter() + .map(|file| Some(file.data_change)) + .collect::(); + + let mut arrays: Vec<(Cow, ArrayRef)> = vec![ + (Cow::Borrowed("path"), Arc::new(paths.finish())), + (Cow::Borrowed("size_bytes"), Arc::new(size)), + (Cow::Borrowed("modification_time"), Arc::new(mod_time)), + (Cow::Borrowed("data_change"), Arc::new(data_change)), + ]; + + let metadata = self.current_metadata().ok_or(DeltaTableError::NoMetadata)?; + + if !metadata.partition_columns.is_empty() { + let partition_cols_batch = self.partition_columns_as_batch(flatten)?; + arrays.extend( + partition_cols_batch + .schema() + .fields + .iter() + .map(|field| Cow::Owned(field.name().clone())) + .zip(partition_cols_batch.columns().iter().map(Arc::clone)), + ) + } + + if self.files().iter().any(|add| add.stats.is_some()) { + let stats = self.stats_as_batch(flatten)?; + arrays.extend( + stats + .schema() + .fields + .iter() + .map(|field| Cow::Owned(field.name().clone())) + .zip(stats.columns().iter().map(Arc::clone)), + ); + } + + if self.files().iter().any(|add| { + add.tags + .as_ref() + .map(|tags| !tags.is_empty()) + .unwrap_or(false) + }) { + let tags = self.tags_as_batch(flatten)?; + arrays.extend( + tags.schema() + .fields + .iter() + .map(|field| Cow::Owned(field.name().clone())) + .zip(tags.columns().iter().map(Arc::clone)), + ); + } + + Ok(arrow::record_batch::RecordBatch::try_from_iter(arrays)?) + } + + fn partition_columns_as_batch( + &self, + flatten: bool, + ) -> Result { + let metadata = self.current_metadata().ok_or(DeltaTableError::NoMetadata)?; + + let partition_column_types: Vec = metadata + .partition_columns + .iter() + .map( + |name| -> Result { + let field = metadata.schema.get_field_with_name(name)?; + Ok(field.get_type().try_into()?) + }, + ) + .collect::>()?; + + // Create builder for each + let mut builders = metadata + .partition_columns + .iter() + .map(|name| { + let builder = arrow::array::StringBuilder::new(); + (name.as_str(), builder) + }) + .collect::>(); + + // Append values + for action in self.files() { + for (name, maybe_value) in action.partition_values.iter() { + if let Some(value) = maybe_value { + builders.get_mut(name.as_str()).unwrap().append_value(value); + } else { + builders.get_mut(name.as_str()).unwrap().append_null(); + } + } + } + + // Cast them to their appropriate types + let partition_columns: Vec = metadata + .partition_columns + .iter() + // Get the builders in their original order + .map(|name| builders.remove(name.as_str()).unwrap()) + .zip(partition_column_types.iter()) + .map(|(mut builder, datatype)| { + let string_arr: ArrayRef = Arc::new(builder.finish()); + Ok(cast(&string_arr, datatype)?) 
+ }) + .collect::>()?; + + // if flatten, append columns, otherwise combine into a struct column + let partition_columns: Vec<(Cow, ArrayRef)> = if flatten { + partition_columns + .into_iter() + .zip(metadata.partition_columns.iter()) + .map(|(array, name)| { + let name: Cow = Cow::Owned(format!("partition.{}", name)); + (name, array) + }) + .collect() + } else { + let fields = partition_column_types + .into_iter() + .zip(metadata.partition_columns.iter()) + .map(|(datatype, name)| arrow::datatypes::Field::new(name, datatype, true)); + let field_arrays = fields + .zip(partition_columns.into_iter()) + .collect::>(); + if field_arrays.is_empty() { + vec![] + } else { + let arr = Arc::new(arrow::array::StructArray::from(field_arrays)); + vec![(Cow::Borrowed("partition_values"), arr)] + } + }; + + Ok(arrow::record_batch::RecordBatch::try_from_iter( + partition_columns, + )?) + } + + fn tags_as_batch( + &self, + flatten: bool, + ) -> Result { + let tag_keys: HashSet<&str> = self + .files() + .iter() + .flat_map(|add| add.tags.as_ref().map(|tags| tags.keys())) + .flatten() + .map(|key| key.as_str()) + .collect(); + let mut builder_map: HashMap<&str, arrow::array::StringBuilder> = tag_keys + .iter() + .map(|&key| { + ( + key, + arrow::array::StringBuilder::with_capacity(self.files().len(), 64), + ) + }) + .collect(); + + for add in self.files() { + for &key in &tag_keys { + if let Some(value) = add + .tags + .as_ref() + .and_then(|tags| tags.get(key)) + .and_then(|val| val.as_deref()) + { + builder_map.get_mut(key).unwrap().append_value(value); + } else { + builder_map.get_mut(key).unwrap().append_null(); + } + } + } + + let mut arrays: Vec<(&str, ArrayRef)> = builder_map + .into_iter() + .map(|(key, mut builder)| (key, Arc::new(builder.finish()) as ArrayRef)) + .collect(); + // Sorted for consistent order + arrays.sort_by(|(key1, _), (key2, _)| key1.cmp(key2)); + if flatten { + Ok(arrow::record_batch::RecordBatch::try_from_iter( + arrays + .into_iter() + .map(|(key, array)| (format!("tags.{}", key), array)), + )?) + } else { + Ok(arrow::record_batch::RecordBatch::try_from_iter(vec![( + "tags", + Arc::new(StructArray::from( + arrays + .into_iter() + .map(|(key, array)| { + (Field::new(key, array.data_type().clone(), true), array) + }) + .collect_vec(), + )) as ArrayRef, + )])?) 
+ } + } + + fn stats_as_batch( + &self, + flatten: bool, + ) -> Result { + let stats: Vec> = self + .files() + .iter() + .map(|f| { + f.get_stats() + .map_err(|err| DeltaTableError::InvalidJson { source: err }) + }) + .collect::>()?; + + let num_records = arrow::array::Int64Array::from( + stats + .iter() + .map(|maybe_stat| maybe_stat.as_ref().map(|stat| stat.num_records)) + .collect::>>(), + ); + let metadata = self.current_metadata().ok_or(DeltaTableError::NoMetadata)?; + let schema = &metadata.schema; + + #[derive(Debug)] + struct ColStats<'a> { + path: Vec<&'a str>, + null_count: Option, + min_values: Option, + max_values: Option, + } + + let mut columnar_stats: Vec = SchemaLeafIterator::new(schema) + .filter(|(_path, datatype)| !matches!(datatype, SchemaDataType::r#struct(_))) + .map(|(path, datatype)| -> Result { + let null_count: Option = stats + .iter() + .flat_map(|maybe_stat| { + maybe_stat + .as_ref() + .map(|stat| resolve_column_count_stat(&stat.null_count, &path)) + }) + .collect::>>() + .map(arrow::array::Int64Array::from) + .map(|arr| -> ArrayRef { Arc::new(arr) }); + + let arrow_type: arrow::datatypes::DataType = datatype.try_into()?; + + // Min and max are collected for primitive values, not list or maps + let min_values = if matches!(datatype, SchemaDataType::primitive(_)) { + stats + .iter() + .flat_map(|maybe_stat| { + maybe_stat + .as_ref() + .map(|stat| resolve_column_value_stat(&stat.min_values, &path)) + }) + .collect::>>() + .map(|min_values| { + json_value_to_array_general(&arrow_type, min_values.into_iter()) + }) + .transpose()? + } else { + None + }; + + let max_values = if matches!(datatype, SchemaDataType::primitive(_)) { + stats + .iter() + .flat_map(|maybe_stat| { + maybe_stat + .as_ref() + .map(|stat| resolve_column_value_stat(&stat.max_values, &path)) + }) + .collect::>>() + .map(|max_values| { + json_value_to_array_general(&arrow_type, max_values.into_iter()) + }) + .transpose()? + } else { + None + }; + + Ok(ColStats { + path, + null_count, + min_values, + max_values, + }) + }) + .collect::>()?; + + let mut out_columns: Vec<(Cow, ArrayRef)> = + vec![(Cow::Borrowed("num_records"), Arc::new(num_records))]; + if flatten { + for col_stats in columnar_stats { + if let Some(null_count) = col_stats.null_count { + out_columns.push(( + Cow::Owned(format!("null_count.{}", col_stats.path.join("."))), + null_count, + )); + } + if let Some(min_values) = col_stats.min_values { + out_columns.push(( + Cow::Owned(format!("min.{}", col_stats.path.join("."))), + min_values, + )); + } + if let Some(max_values) = col_stats.max_values { + out_columns.push(( + Cow::Owned(format!("max.{}", col_stats.path.join("."))), + max_values, + )); + } + } + } else { + let mut level = columnar_stats + .iter() + .map(|col_stat| col_stat.path.len()) + .max() + .unwrap_or(0); + + let combine_arrays = |sub_fields: &Vec, + getter: for<'a> fn(&'a ColStats) -> &'a Option| + -> Option { + let fields = sub_fields + .iter() + .flat_map(|sub_field| { + if let Some(values) = getter(sub_field) { + let field = Field::new( + sub_field + .path + .last() + .expect("paths must have at least one element"), + values.data_type().clone(), + false, + ); + Some((field, Arc::clone(values))) + } else { + None + } + }) + .collect::>(); + if fields.is_empty() { + None + } else { + Some(Arc::new(StructArray::from(fields))) + } + }; + + while level > 0 { + // Starting with most nested level, iteratively group null_count, min_values, max_values + // into StructArrays, until it is consolidated into a single array. 
+ columnar_stats = columnar_stats + .into_iter() + .group_by(|col_stat| { + if col_stat.path.len() < level { + col_stat.path.clone() + } else { + col_stat.path[0..(level - 1)].to_vec() + } + }) + .into_iter() + .map(|(prefix, group)| { + let current_fields: Vec = group.into_iter().collect(); + if current_fields[0].path.len() < level { + debug_assert_eq!(current_fields.len(), 1); + current_fields.into_iter().next().unwrap() + } else { + ColStats { + path: prefix.to_vec(), + null_count: combine_arrays(¤t_fields, |sub_field| { + &sub_field.null_count + }), + min_values: combine_arrays(¤t_fields, |sub_field| { + &sub_field.min_values + }), + max_values: combine_arrays(¤t_fields, |sub_field| { + &sub_field.max_values + }), + } + } + }) + .collect(); + level -= 1; + } + debug_assert!(columnar_stats.len() == 1); + debug_assert!(columnar_stats + .iter() + .all(|col_stat| col_stat.path.is_empty())); + + if let Some(null_count) = columnar_stats[0].null_count.take() { + out_columns.push((Cow::Borrowed("null_count"), null_count)); + } + if let Some(min_values) = columnar_stats[0].min_values.take() { + out_columns.push((Cow::Borrowed("min"), min_values)); + } + if let Some(max_values) = columnar_stats[0].max_values.take() { + out_columns.push((Cow::Borrowed("max"), max_values)); + } + } + + Ok(arrow::record_batch::RecordBatch::try_from_iter( + out_columns, + )?) + } +} + +fn resolve_column_value_stat<'a>( + values: &'a HashMap, + path: &[&'a str], +) -> Option<&'a serde_json::Value> { + let mut current = values; + let (&name, path) = path.split_last()?; + for &segment in path { + current = current.get(segment)?.as_column()?; + } + let current = current.get(name)?; + current.as_value() +} + +fn resolve_column_count_stat( + values: &HashMap, + path: &[&str], +) -> Option { + let mut current = values; + let (&name, path) = path.split_last()?; + for &segment in path { + current = current.get(segment)?.as_column()?; + } + let current = current.get(name)?; + current.as_value() +} + +struct SchemaLeafIterator<'a> { + fields_remaining: VecDeque<(Vec<&'a str>, &'a SchemaDataType)>, +} + +impl<'a> SchemaLeafIterator<'a> { + fn new(schema: &'a SchemaTypeStruct) -> Self { + SchemaLeafIterator { + fields_remaining: schema + .get_fields() + .iter() + .map(|field| (vec![field.get_name()], field.get_type())) + .collect(), + } + } +} + +impl<'a> std::iter::Iterator for SchemaLeafIterator<'a> { + type Item = (Vec<&'a str>, &'a SchemaDataType); + + fn next(&mut self) -> Option { + if let Some((path, datatype)) = self.fields_remaining.pop_front() { + if let SchemaDataType::r#struct(struct_type) = datatype { + // push child fields to front + for field in struct_type.get_fields() { + let mut new_path = path.clone(); + new_path.push(field.get_name()); + self.fields_remaining + .push_front((new_path, field.get_type())); + } + }; + + Some((path, datatype)) + } else { + None + } + } +} + +fn json_value_to_array_general<'a>( + datatype: &arrow::datatypes::DataType, + values: impl Iterator, +) -> Result { + match datatype { + DataType::Boolean => Ok(Arc::new( + values + .map(|value| value.as_bool()) + .collect::(), + )), + DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => { + let i64_arr: ArrayRef = + Arc::new(values.map(|value| value.as_i64()).collect::()); + Ok(arrow::compute::cast(&i64_arr, datatype)?) 
+ } + DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _) => { + let f64_arr: ArrayRef = + Arc::new(values.map(|value| value.as_f64()).collect::()); + Ok(arrow::compute::cast(&f64_arr, datatype)?) + } + DataType::Utf8 => Ok(Arc::new( + values.map(|value| value.as_str()).collect::(), + )), + DataType::Binary => Ok(Arc::new( + values.map(|value| value.as_str()).collect::(), + )), + DataType::Timestamp(TimeUnit::Microsecond, None) => { + Ok(Arc::new(TimestampMicrosecondArray::from( + values + .map(|value| value.as_str().and_then(TimestampMicrosecondType::parse)) + .collect::>>(), + ))) + } + DataType::Date32 => Ok(Arc::new(Date32Array::from( + values + .map(|value| value.as_str().and_then(Date32Type::parse)) + .collect::>>(), + ))), + _ => Err(DeltaTableError::Generic("Invalid datatype".to_string())), + } +} diff --git a/rust/tests/add_actions_test.rs b/rust/tests/add_actions_test.rs new file mode 100644 index 0000000000..0ba51b6bbb --- /dev/null +++ b/rust/tests/add_actions_test.rs @@ -0,0 +1,448 @@ +#![cfg(feature = "arrow")] + +use arrow::array::{self, ArrayRef, StructArray}; +use arrow::compute::kernels::cast_utils::Parser; +use arrow::compute::sort_to_indices; +use arrow::datatypes::{DataType, Date32Type, Field, TimestampMicrosecondType}; +use arrow::record_batch::RecordBatch; +use std::sync::Arc; + +fn sort_batch_by(batch: &RecordBatch, column: &str) -> arrow::error::Result { + let sort_column = batch.column(batch.schema().column_with_name(column).unwrap().0); + let sort_indices = sort_to_indices(sort_column, None, None)?; + let schema = batch.schema(); + let sorted_columns: Vec<(&String, ArrayRef)> = schema + .fields() + .iter() + .zip(batch.columns().iter()) + .map(|(field, column)| { + Ok(( + field.name(), + arrow::compute::take(column, &sort_indices, None)?, + )) + }) + .collect::>()?; + RecordBatch::try_from_iter(sorted_columns) +} + +#[tokio::test] +async fn test_with_partitions() { + // test table with partitions + let path = "./tests/data/delta-0.8.0-null-partition"; + let table = deltalake::open_table(path).await.unwrap(); + let actions = table.get_state().add_actions_table(true).unwrap(); + let actions = sort_batch_by(&actions, "path").unwrap(); + + let mut expected_columns: Vec<(&str, ArrayRef)> = vec![ + ("path", Arc::new(array::StringArray::from(vec![ + "k=A/part-00000-b1f1dbbb-70bc-4970-893f-9bb772bf246e.c000.snappy.parquet", + "k=__HIVE_DEFAULT_PARTITION__/part-00001-8474ac85-360b-4f58-b3ea-23990c71b932.c000.snappy.parquet" + ]))), + ("size_bytes", Arc::new(array::Int64Array::from(vec![460, 460]))), + ("modification_time", Arc::new(arrow::array::TimestampMillisecondArray::from(vec![ + 1627990384000, 1627990384000 + ]))), + ("data_change", Arc::new(array::BooleanArray::from(vec![true, true]))), + ("partition.k", Arc::new(array::StringArray::from(vec![Some("A"), None]))), + ]; + let expected = RecordBatch::try_from_iter(expected_columns.clone()).unwrap(); + + assert_eq!(expected, actions); + + let actions = table.get_state().add_actions_table(false).unwrap(); + let actions = sort_batch_by(&actions, "path").unwrap(); + + expected_columns[4] = ( + "partition_values", + Arc::new(array::StructArray::from(vec![( + Field::new("k", DataType::Utf8, true), + Arc::new(array::StringArray::from(vec![Some("A"), None])) as ArrayRef, + )])), + ); + let expected = RecordBatch::try_from_iter(expected_columns).unwrap(); + + assert_eq!(expected, actions); +} + +#[tokio::test] +async fn test_without_partitions() { + // test table without partitions + let path = 
"./tests/data/simple_table"; + let table = deltalake::open_table(path).await.unwrap(); + + let actions = table.get_state().add_actions_table(true).unwrap(); + let actions = sort_batch_by(&actions, "path").unwrap(); + + let expected_columns: Vec<(&str, ArrayRef)> = vec![ + ( + "path", + Arc::new(array::StringArray::from(vec![ + "part-00000-2befed33-c358-4768-a43c-3eda0d2a499d-c000.snappy.parquet", + "part-00000-c1777d7d-89d9-4790-b38a-6ee7e24456b1-c000.snappy.parquet", + "part-00001-7891c33d-cedc-47c3-88a6-abcfb049d3b4-c000.snappy.parquet", + "part-00004-315835fe-fb44-4562-98f6-5e6cfa3ae45d-c000.snappy.parquet", + "part-00007-3a0e4727-de0d-41b6-81ef-5223cf40f025-c000.snappy.parquet", + ])), + ), + ( + "size_bytes", + Arc::new(array::Int64Array::from(vec![262, 262, 429, 429, 429])), + ), + ( + "modification_time", + Arc::new(arrow::array::TimestampMillisecondArray::from(vec![ + 1587968626000, + 1587968602000, + 1587968602000, + 1587968602000, + 1587968602000, + ])), + ), + ( + "data_change", + Arc::new(array::BooleanArray::from(vec![ + true, true, true, true, true, + ])), + ), + ]; + let expected = RecordBatch::try_from_iter(expected_columns.clone()).unwrap(); + + assert_eq!(expected, actions); + + let actions = table.get_state().add_actions_table(false).unwrap(); + let actions = sort_batch_by(&actions, "path").unwrap(); + + // For now, this column is ignored. + // expected_columns.push(( + // "partition_values", + // new_null_array(&DataType::Struct(vec![]), 5), + // )); + let expected = RecordBatch::try_from_iter(expected_columns.clone()).unwrap(); + + assert_eq!(expected, actions); +} + +#[tokio::test] +async fn test_with_stats() { + // test table with stats + let path = "./tests/data/delta-0.8.0"; + let table = deltalake::open_table(path).await.unwrap(); + let actions = table.get_state().add_actions_table(true).unwrap(); + let actions = sort_batch_by(&actions, "path").unwrap(); + + let expected_columns: Vec<(&str, ArrayRef)> = vec![ + ( + "path", + Arc::new(array::StringArray::from(vec![ + "part-00000-04ec9591-0b73-459e-8d18-ba5711d6cbe1-c000.snappy.parquet", + "part-00000-c9b90f86-73e6-46c8-93ba-ff6bfaf892a1-c000.snappy.parquet", + ])), + ), + ( + "size_bytes", + Arc::new(array::Int64Array::from(vec![440, 440])), + ), + ( + "modification_time", + Arc::new(arrow::array::TimestampMillisecondArray::from(vec![ + 1615043776000, + 1615043767000, + ])), + ), + ( + "data_change", + Arc::new(array::BooleanArray::from(vec![true, true])), + ), + ("num_records", Arc::new(array::Int64Array::from(vec![2, 2]))), + ( + "null_count.value", + Arc::new(array::Int64Array::from(vec![0, 0])), + ), + ("min.value", Arc::new(array::Int32Array::from(vec![2, 0]))), + ("max.value", Arc::new(array::Int32Array::from(vec![4, 2]))), + ]; + let expected = RecordBatch::try_from_iter(expected_columns.clone()).unwrap(); + + assert_eq!(expected, actions); +} + +#[tokio::test] +async fn test_only_struct_stats() { + // test table with no json stats + let path = "./tests/data/delta-1.2.1-only-struct-stats"; + let mut table = deltalake::open_table(path).await.unwrap(); + table.load_version(1).await.unwrap(); + + let actions = table.get_state().add_actions_table(true).unwrap(); + + let expected_columns: Vec<(&str, ArrayRef)> = vec![ + ( + "path", + Arc::new(array::StringArray::from(vec![ + "part-00000-7a509247-4f58-4453-9202-51d75dee59af-c000.snappy.parquet", + ])), + ), + ("size_bytes", Arc::new(array::Int64Array::from(vec![5489]))), + ( + "modification_time", + Arc::new(arrow::array::TimestampMillisecondArray::from(vec![ + 
1666652373000, + ])), + ), + ( + "data_change", + Arc::new(array::BooleanArray::from(vec![true])), + ), + ("num_records", Arc::new(array::Int64Array::from(vec![1]))), + ( + "null_count.integer", + Arc::new(array::Int64Array::from(vec![0])), + ), + ("min.integer", Arc::new(array::Int32Array::from(vec![0]))), + ("max.integer", Arc::new(array::Int32Array::from(vec![0]))), + ( + "null_count.null", + Arc::new(array::Int64Array::from(vec![1])), + ), + ( + "null_count.boolean", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "null_count.double", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.double", + Arc::new(array::Float64Array::from(vec![1.234])), + ), + ( + "max.double", + Arc::new(array::Float64Array::from(vec![1.234])), + ), + ( + "null_count.decimal", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.decimal", + Arc::new( + array::Decimal128Array::from_iter_values([-567800]) + .with_precision_and_scale(8, 5) + .unwrap(), + ), + ), + ( + "max.decimal", + Arc::new( + array::Decimal128Array::from_iter_values([-567800]) + .with_precision_and_scale(8, 5) + .unwrap(), + ), + ), + ( + "null_count.string", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.string", + Arc::new(array::StringArray::from(vec!["string"])), + ), + ( + "max.string", + Arc::new(array::StringArray::from(vec!["string"])), + ), + ( + "null_count.binary", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "null_count.date", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.date", + Arc::new(array::Date32Array::from(vec![Date32Type::parse( + "2022-10-24", + )])), + ), + ( + "max.date", + Arc::new(array::Date32Array::from(vec![Date32Type::parse( + "2022-10-24", + )])), + ), + ( + "null_count.timestamp", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.timestamp", + Arc::new(array::TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), + ])), + ), + ( + "max.timestamp", + Arc::new(array::TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), + ])), + ), + ( + "null_count.struct.struct_element", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.struct.struct_element", + Arc::new(array::StringArray::from(vec!["struct_value"])), + ), + ( + "max.struct.struct_element", + Arc::new(array::StringArray::from(vec!["struct_value"])), + ), + ("null_count.map", Arc::new(array::Int64Array::from(vec![0]))), + ( + "null_count.array", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "null_count.nested_struct.struct_element.nested_struct_element", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "min.nested_struct.struct_element.nested_struct_element", + Arc::new(array::StringArray::from(vec!["nested_struct_value"])), + ), + ( + "max.nested_struct.struct_element.nested_struct_element", + Arc::new(array::StringArray::from(vec!["nested_struct_value"])), + ), + ( + "null_count.struct_of_array_of_map.struct_element", + Arc::new(array::Int64Array::from(vec![0])), + ), + ( + "tags.INSERTION_TIME", + Arc::new(array::StringArray::from(vec!["1666652373000000"])), + ), + ( + "tags.OPTIMIZE_TARGET_SIZE", + Arc::new(array::StringArray::from(vec!["268435456"])), + ), + ]; + let expected = RecordBatch::try_from_iter(expected_columns.clone()).unwrap(); + + assert_eq!( + expected + .schema() + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>(), + actions + .schema() + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>() + 
); + assert_eq!(expected, actions); + + let actions = table.get_state().add_actions_table(false).unwrap(); + // For brevity, just checking a few nested columns in stats + + assert_eq!( + actions + .get_field_at_path(&vec![ + "null_count", + "nested_struct", + "struct_element", + "nested_struct_element" + ]) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(), + &array::Int64Array::from(vec![0]), + ); + + assert_eq!( + actions + .get_field_at_path(&vec![ + "min", + "nested_struct", + "struct_element", + "nested_struct_element" + ]) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(), + &array::StringArray::from(vec!["nested_struct_value"]), + ); + + assert_eq!( + actions + .get_field_at_path(&vec![ + "max", + "nested_struct", + "struct_element", + "nested_struct_element" + ]) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(), + &array::StringArray::from(vec!["nested_struct_value"]), + ); + + assert_eq!( + actions + .get_field_at_path(&vec![ + "null_count", + "struct_of_array_of_map", + "struct_element" + ]) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(), + &array::Int64Array::from(vec![0]) + ); + + assert_eq!( + actions + .get_field_at_path(&vec!["tags", "OPTIMIZE_TARGET_SIZE"]) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(), + &array::StringArray::from(vec!["268435456"]) + ); +} + +/// Trait to make it easier to access nested fields +trait NestedTabular { + fn get_field_at_path(&self, path: &[&str]) -> Option; +} + +impl NestedTabular for RecordBatch { + fn get_field_at_path(&self, path: &[&str]) -> Option { + // First, get array in the batch + let (first_key, remainder) = path.split_at(1); + let mut col = self.column(self.schema().column_with_name(first_key[0])?.0); + + if remainder.is_empty() { + return Some(Arc::clone(col)); + } + + for segment in remainder { + col = col + .as_any() + .downcast_ref::()? + .column_by_name(segment)?; + } + + Some(Arc::clone(col)) + } +}
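Usage sketch (illustrative, not part of the diff above): a minimal Rust program calling the new `DeltaTableState::add_actions_table` API against one of the test tables referenced in `add_actions_test.rs`. It assumes the crate's `arrow` feature is enabled, a tokio runtime with the `macros` feature is available (as in the tests), and that it is run from the `rust/` crate root so the relative test-data path resolves.

    use arrow::record_batch::RecordBatch;

    #[tokio::main]
    async fn main() -> Result<(), deltalake::DeltaTableError> {
        // Open a test table that ships with the repository (path relative to rust/).
        let table = deltalake::open_table("./tests/data/delta-0.8.0").await?;

        // flatten = true yields one top-level column per partition value and
        // per-column statistic, e.g. `null_count.value`, `min.value`, `max.value`.
        let batch: RecordBatch = table.get_state().add_actions_table(true)?;

        // Print the number of add actions (files) and the resulting schema.
        println!("{} add actions", batch.num_rows());
        let schema = batch.schema();
        for field in schema.fields() {
            println!("{}: {:?}", field.name(), field.data_type());
        }
        Ok(())
    }

With `flatten = false`, the same call would instead return nested `partition_values`, `null_count`, `min`, `max`, and `tags` struct columns, as exercised in `test_only_struct_stats`.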