fix: avoid writing statistics for binary columns to fix JSON error (delta-io#1498)

# Description
Avoid writing statistics for binary columns, fixing the JSON serialization error thrown by Arrow during checkpoint creation.

# Related Issue(s)
 - closes delta-io#1493
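
For context, here is a minimal sketch of the failing scenario, adapted from the test added in this change. It assumes a table whose Delta schema already contains a single `binary` column; the `repro` helper name is hypothetical.

```rust
use std::sync::Arc;

use arrow_array::{BinaryArray, RecordBatch};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use deltalake::writer::{DeltaWriter, RecordBatchWriter};
use deltalake::{checkpoints, DeltaTable};

// Hypothetical helper: write one batch containing a binary column, then
// create a checkpoint for the table.
async fn repro(mut table: DeltaTable) -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "bin",
        DataType::Binary,
        false,
    )]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(BinaryArray::from(vec![&[1u8, 2][..]]))],
    )?;

    let mut writer = RecordBatchWriter::for_table(&table)?;
    writer.write(batch).await?;
    writer.flush_and_commit(&mut table).await?;

    // Before this change, checkpoint creation failed with a JSON error
    // (delta-io#1493); statistics for the binary column are now skipped.
    checkpoints::create_checkpoint(&table).await?;
    Ok(())
}
```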

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
Co-authored-by: R. Tyler Croy <rtyler@brokenco.de>
3 people authored Jul 15, 2023
1 parent 4a4aaa9 commit 312d1c2
Showing 4 changed files with 140 additions and 24 deletions.
29 changes: 28 additions & 1 deletion rust/src/action/checkpoints.rs
@@ -283,6 +283,33 @@ pub async fn cleanup_expired_logs_for(
}
}

/// Filter binary from the schema so that it isn't serialized into JSON,
/// as arrow currently does not support this.
fn filter_binary(schema: &Schema) -> Schema {
Schema::new(
schema
.get_fields()
.iter()
.flat_map(|f| match f.get_type() {
SchemaDataType::primitive(p) => {
if p != "binary" {
Some(f.clone())
} else {
None
}
}
SchemaDataType::r#struct(s) => Some(SchemaField::new(
f.get_name().to_string(),
SchemaDataType::r#struct(filter_binary(&Schema::new(s.get_fields().clone()))),
f.is_nullable(),
f.get_metadata().clone(),
)),
_ => Some(f.clone()),
})
.collect::<Vec<_>>(),
)
}

fn parquet_bytes_from_state(
state: &DeltaTableState,
) -> Result<(CheckPoint, bytes::Bytes), ProtocolError> {
@@ -357,7 +384,7 @@ fn parquet_bytes_from_state(

// Create the arrow schema that represents the Checkpoint parquet file.
let arrow_schema = delta_log_schema_for_table(
- <ArrowSchema as TryFrom<&Schema>>::try_from(&current_metadata.schema)?,
+ <ArrowSchema as TryFrom<&Schema>>::try_from(&filter_binary(&current_metadata.schema))?,
current_metadata.partition_columns.as_slice(),
use_extended_remove_schema,
);
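To illustrate the new helper: a hypothetical in-module unit test (not part of this commit) showing that `filter_binary` drops top-level binary fields while keeping everything else. It uses the same `Schema`/`SchemaField` constructors that appear in the test file further below.

```rust
#[test]
fn filter_binary_drops_binary_fields() {
    use std::collections::HashMap;

    let schema = Schema::new(vec![
        SchemaField::new(
            "id".to_string(),
            SchemaDataType::primitive("long".to_string()),
            false,
            HashMap::new(),
        ),
        SchemaField::new(
            "raw".to_string(),
            SchemaDataType::primitive("binary".to_string()),
            true,
            HashMap::new(),
        ),
    ]);

    let filtered = filter_binary(&schema);

    // Only the non-binary field survives; struct fields would be filtered
    // recursively by the same function.
    assert_eq!(filtered.get_fields().len(), 1);
    assert_eq!(filtered.get_fields()[0].get_name(), "id");
}
```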
10 changes: 3 additions & 7 deletions rust/src/action/parquet_read/mod.rs
@@ -2,7 +2,6 @@ use std::collections::HashMap;

use chrono::{SecondsFormat, TimeZone, Utc};
use num_bigint::BigInt;
- use num_traits::cast::ToPrimitive;
use parquet::record::{Field, ListAccessor, MapAccessor, RowAccessor};
use serde_json::json;

@@ -255,12 +254,9 @@ fn primitive_parquet_field_to_json_value(field: &Field) -> Result<serde_json::Va
Field::Float(value) => Ok(json!(value)),
Field::Double(value) => Ok(json!(value)),
Field::Str(value) => Ok(json!(value)),
- Field::Decimal(decimal) => match BigInt::from_signed_bytes_be(decimal.data()).to_f64() {
- Some(int) => Ok(json!(
- int / (10_i64.pow((decimal.scale()).try_into().unwrap()) as f64)
- )),
- _ => Err("Invalid type for min/max values."),
- },
+ Field::Decimal(decimal) => Ok(serde_json::Value::String(
+ BigInt::from_signed_bytes_be(decimal.data()).to_string(),
+ )),
Field::TimestampMicros(timestamp) => Ok(serde_json::Value::String(
convert_timestamp_micros_to_string(*timestamp)?,
)),
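The decimal arm above now renders the unscaled value as a string instead of converting through `f64`. A small standalone sketch (illustrative only, not from the codebase) of what the new arm produces:

```rust
use num_bigint::BigInt;

fn main() {
    // 0x04D2 is the big-endian two's-complement encoding of 1234.
    let unscaled = BigInt::from_signed_bytes_be(&[0x04, 0xD2]);
    assert_eq!(unscaled.to_string(), "1234");

    // The new match arm emits exactly this string; the decimal scale is not
    // applied at this point.
    let value = serde_json::Value::String(unscaled.to_string());
    assert_eq!(value, serde_json::json!("1234"));
}
```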
45 changes: 29 additions & 16 deletions rust/src/writer/stats.rs
@@ -120,7 +120,7 @@ enum StatsScalar {
Date(chrono::NaiveDate),
Timestamp(chrono::NaiveDateTime),
// We are serializing to f64 later and the ordering should be the same
- Decimal(f64),
+ Decimal(String),
String(String),
Bytes(Vec<u8>),
Uuid(uuid::Uuid),
@@ -157,7 +157,7 @@ impl StatsScalar {
(Statistics::Int32(v), Some(LogicalType::Decimal { scale, .. })) => {
let val = get_stat!(v) as f64 / 10.0_f64.powi(*scale);
// Spark serializes these as numbers
- Ok(Self::Decimal(val))
+ Ok(Self::Decimal(val.to_string()))
}
(Statistics::Int32(v), _) => Ok(Self::Int32(get_stat!(v))),
// Int64 can be timestamp, decimal, or integer
@@ -184,7 +184,7 @@ impl StatsScalar {
(Statistics::Int64(v), Some(LogicalType::Decimal { scale, .. })) => {
let val = get_stat!(v) as f64 / 10.0_f64.powi(*scale);
// Spark serializes these as numbers
- Ok(Self::Decimal(val))
+ Ok(Self::Decimal(val.to_string()))
}
(Statistics::Int64(v), _) => Ok(Self::Int64(get_stat!(v))),
(Statistics::Float(v), _) => Ok(Self::Float32(get_stat!(v))),
@@ -220,16 +220,16 @@

let val = if val.len() <= 4 {
let mut bytes = [0; 4];
- bytes[..val.len()].copy_from_slice(val);
- i32::from_be_bytes(bytes) as f64
+ bytes[(4 - val.len())..4].copy_from_slice(val);
+ i32::from_be_bytes(bytes).to_string()
} else if val.len() <= 8 {
let mut bytes = [0; 8];
- bytes[..val.len()].copy_from_slice(val);
- i64::from_be_bytes(bytes) as f64
+ bytes[(8 - val.len())..8].copy_from_slice(val);
+ i64::from_be_bytes(bytes).to_string()
} else if val.len() <= 16 {
let mut bytes = [0; 16];
- bytes[..val.len()].copy_from_slice(val);
- i128::from_be_bytes(bytes) as f64
+ bytes[(16 - val.len())..16].copy_from_slice(val);
+ i128::from_be_bytes(bytes).to_string()
} else {
return Err(DeltaWriterError::StatsParsingFailed {
debug_value: format!("{val:?}"),
@@ -240,8 +240,21 @@
});
};

- let val = val / 10.0_f64.powi(*scale);
- Ok(Self::Decimal(val))
+ let decimal_string = if val.len() > *scale as usize {
+ let (integer_part, fractional_part) = val.split_at(val.len() - *scale as usize);
+ if fractional_part.is_empty() {
+ integer_part.to_string()
+ } else {
+ format!("{}.{}", integer_part, fractional_part)
+ }
+ } else if *scale < 0 {
+ let abs_scale = scale.unsigned_abs() as usize;
+ let decimal_zeros = "0".repeat(abs_scale);
+ format!("{}{}", val, decimal_zeros)
+ } else {
+ format!("0.{}", val)
+ };
+ Ok(Self::Decimal(decimal_string))
}
(Statistics::FixedLenByteArray(v), Some(LogicalType::Uuid)) => {
let val = if use_min {
@@ -528,15 +541,15 @@ mod tests {
scale: 3,
precision: 4,
}),
- Value::from(1.234),
+ Value::from("1.234"),
),
(
simple_parquet_stat!(Statistics::Int32, 1234),
Some(LogicalType::Decimal {
scale: -1,
precision: 4,
}),
- Value::from(12340.0),
+ Value::from("12340"),
),
(
simple_parquet_stat!(Statistics::Int32, 737821),
@@ -573,15 +586,15 @@ mod tests {
scale: 3,
precision: 4,
}),
- Value::from(1.234),
+ Value::from("1.234"),
),
(
simple_parquet_stat!(Statistics::Int64, 1234),
Some(LogicalType::Decimal {
scale: -1,
precision: 4,
}),
- Value::from(12340.0),
+ Value::from("12340"),
),
(
simple_parquet_stat!(Statistics::Int64, 1234),
@@ -607,7 +620,7 @@ mod tests {
scale: 3,
precision: 16,
}),
- Value::from(1243124142314.423),
+ Value::from("1243124142314.423"),
),
(
simple_parquet_stat!(
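For reference, a standalone sketch of the decimal-to-string formatting introduced above, limited to the cases the updated tests exercise; the function name is hypothetical and negative unscaled values are ignored.

```rust
// Mirrors the branch logic added to the decimal handling in stats.rs.
fn format_unscaled_decimal(digits: &str, scale: i32) -> String {
    if digits.len() > scale as usize {
        // Insert the decimal point `scale` digits from the right.
        let (int_part, frac_part) = digits.split_at(digits.len() - scale as usize);
        if frac_part.is_empty() {
            int_part.to_string()
        } else {
            format!("{}.{}", int_part, frac_part)
        }
    } else if scale < 0 {
        // A negative scale multiplies the unscaled value by 10^|scale|.
        format!("{}{}", digits, "0".repeat(scale.unsigned_abs() as usize))
    } else {
        format!("0.{}", digits)
    }
}

fn main() {
    // These match the expectations in the updated tests above.
    assert_eq!(format_unscaled_decimal("1234", 3), "1.234");
    assert_eq!(format_unscaled_decimal("1234", -1), "12340");
    assert_eq!(format_unscaled_decimal("1243124142314423", 3), "1243124142314.423");
}
```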
80 changes: 80 additions & 0 deletions rust/tests/checkpoint_writer.rs
@@ -8,10 +8,90 @@ mod fs_common;

#[cfg(all(feature = "arrow", feature = "parquet"))]
mod simple_checkpoint {
use arrow::datatypes::Schema as ArrowSchema;
use arrow_array::{BinaryArray, Decimal128Array, RecordBatch};
use arrow_schema::{DataType, Field};
use deltalake::writer::{DeltaWriter, RecordBatchWriter};
use deltalake::*;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::error::Error;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;

struct Context {
pub table: DeltaTable,
}

async fn setup_test() -> Result<Context, Box<dyn Error>> {
let columns = vec![
SchemaField::new(
"bin".to_owned(),
SchemaDataType::primitive("binary".to_owned()),
false,
HashMap::new(),
),
SchemaField::new(
"dec".to_owned(),
SchemaDataType::primitive("decimal(23,0)".to_owned()),
false,
HashMap::new(),
),
];

let tmp_dir = tempdir::TempDir::new("opt_table").unwrap();
let table_uri = tmp_dir.path().to_str().to_owned().unwrap();
let dt = DeltaOps::try_from_uri(table_uri)
.await?
.create()
.with_columns(columns)
.await?;

Ok(Context { table: dt })
}

fn get_batch(items: Vec<&[u8]>, decimals: Vec<i128>) -> Result<RecordBatch, Box<dyn Error>> {
let x_array = BinaryArray::from(items);
let dec_array = Decimal128Array::from(decimals).with_precision_and_scale(23, 0)?;

Ok(RecordBatch::try_new(
Arc::new(ArrowSchema::new(vec![
Field::new("bin", DataType::Binary, false),
Field::new("dec", DataType::Decimal128(23, 0), false),
])),
vec![Arc::new(x_array), Arc::new(dec_array)],
)?)
}

async fn write(
writer: &mut RecordBatchWriter,
table: &mut DeltaTable,
batch: RecordBatch,
) -> Result<(), DeltaTableError> {
writer.write(batch).await?;
writer.flush_and_commit(table).await?;
Ok(())
}

#[tokio::test]
async fn test_checkpoint_write_binary_stats() -> Result<(), Box<dyn Error>> {
let context = setup_test().await?;
let mut dt = context.table;
let mut writer = RecordBatchWriter::for_table(&dt)?;

write(
&mut writer,
&mut dt,
get_batch(vec![&[1, 2]], vec![18446744073709551614])?,
)
.await?;

// Just checking that this doesn't fail. https://github.com/delta-io/delta-rs/issues/1493
checkpoints::create_checkpoint(&dt).await?;

Ok(())
}

#[tokio::test]
async fn simple_checkpoint_test() {
