Skip to content

Commit

Permalink
vtab::arrow: support UUID extension type
Browse files Browse the repository at this point in the history
  • Loading branch information
ajwerner committed Nov 20, 2024
1 parent 2bd811e commit b96bf56
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 26 deletions.
2 changes: 1 addition & 1 deletion crates/duckdb/src/appender/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Appender<'_> {
let schema = record_batch.schema();
let mut logical_type: Vec<LogicalTypeHandle> = vec![];
for field in schema.fields() {
let logical_t = to_duckdb_logical_type(field.data_type())
let logical_t = to_duckdb_logical_type(field.data_type(), field.metadata())
.map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))?;
logical_type.push(logical_t);
}
Expand Down
11 changes: 10 additions & 1 deletion crates/duckdb/src/core/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use crate::ffi::{
duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk,
duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, duckdb_vector_get_column_type,
};

/// Handle to the DataChunk in DuckDB.
Expand Down Expand Up @@ -59,6 +59,15 @@ impl DataChunkHandle {
StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
}

/// Returns the logical type of the column vector stored at index `idx` of this chunk.
///
/// The index is forwarded to DuckDB as-is; no range check is done on this side.
pub fn logical_type(&self, idx: usize) -> LogicalTypeHandle {
    // Fetch the column vector first, then ask DuckDB for its type.
    let column = unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) };
    let raw_type = unsafe { duckdb_vector_get_column_type(column) };
    // Wrap the raw type pointer returned by DuckDB in a handle.
    unsafe { LogicalTypeHandle::new(raw_type) }
}

/// Set the size of the data chunk
pub fn set_len(&self, new_len: usize) {
unsafe { duckdb_data_chunk_set_size(self.ptr, new_len as u64) };
Expand Down
156 changes: 132 additions & 24 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab};
use std::ptr::null_mut;
use std::{collections::HashMap, ptr::null_mut};

use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector};
use arrow::{
Expand Down Expand Up @@ -88,7 +88,7 @@ impl VTab for ArrowVTab {
for f in rb.schema().fields() {
let name = f.name();
let data_type = f.data_type();
let logical_type = to_duckdb_logical_type(data_type)?;
let logical_type = to_duckdb_logical_type(data_type, f.metadata())?;
bind.add_result_column(name, logical_type);
}
(*data).rb = Box::into_raw(Box::new(rb));
Expand Down Expand Up @@ -128,8 +128,15 @@ impl VTab for ArrowVTab {
}
}

/// Field-metadata key that is looked up to detect an Arrow extension type.
const EXTENSION_NAME_KEY: &str = "ARROW:extension:name";
/// Value stored under [`EXTENSION_NAME_KEY`] that marks a `FixedSizeBinary` field as a UUID.
const UUID_EXTENSION_NAME: &str = "arrow.uuid";
/// Width in bytes of a single UUID value (16 bytes = 128 bits).
const UUID_LENGTH: usize = 16;

/// Convert arrow DataType to duckdb type id
pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn std::error::Error>> {
pub fn to_duckdb_type_id(
data_type: &DataType,
metadata: &HashMap<String, String>,
) -> Result<LogicalTypeId, Box<dyn std::error::Error>> {
use LogicalTypeId::*;

let type_id = match data_type {
Expand Down Expand Up @@ -157,7 +164,17 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
DataType::Interval(_) => Interval,
DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => Blob,
DataType::FixedSizeBinary(_) => {
if metadata
.get(EXTENSION_NAME_KEY)
.map_or(false, |name| name == UUID_EXTENSION_NAME)
{
Uuid
} else {
Blob
}
}
DataType::Binary | DataType::LargeBinary => Blob,
DataType::Utf8 | DataType::LargeUtf8 => Varchar,
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List,
DataType::Struct(_) => Struct,
Expand All @@ -177,21 +194,28 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
}

/// Convert arrow DataType to duckdb logical type
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
pub fn to_duckdb_logical_type(
data_type: &DataType,
metadata: &HashMap<String, String>,
) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
match data_type {
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type, &HashMap::new()),
DataType::Struct(fields) => {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
shape.push((
field.name().as_str(),
to_duckdb_logical_type(field.data_type(), field.metadata())?,
));
}
Ok(LogicalTypeHandle::struct_type(shape.as_slice()))
}
DataType::List(child) | DataType::LargeList(child) => {
Ok(LogicalTypeHandle::list(&to_duckdb_logical_type(child.data_type())?))
}
DataType::List(child) | DataType::LargeList(child) => Ok(LogicalTypeHandle::list(&to_duckdb_logical_type(
child.data_type(),
child.metadata(),
)?)),
DataType::FixedSizeList(child, array_size) => Ok(LogicalTypeHandle::array(
&to_duckdb_logical_type(child.data_type())?,
&to_duckdb_logical_type(child.data_type(), child.metadata())?,
*array_size as u64,
)),
DataType::Decimal128(width, scale) if *scale > 0 => {
Expand All @@ -203,8 +227,8 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle,
| DataType::LargeUtf8
| DataType::Binary
| DataType::LargeBinary
| DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
| DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)),
dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)),
_ => Err(format!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
)
Expand Down Expand Up @@ -245,8 +269,15 @@ pub fn record_batch_to_duckdb_data_chunk(
DataType::Binary => {
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
}
DataType::FixedSizeBinary(_) => {
fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
DataType::FixedSizeBinary(length) => {
if chunk.logical_type(i).id() == LogicalTypeId::Uuid {
if *length != UUID_LENGTH as i32 {
return Err(format!("UUID FixedSizeBinaryArray must have value length of 16").into());
}
uuid_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
} else {
fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
}
}
DataType::LargeBinary => {
large_binary_array_to_vector(
Expand Down Expand Up @@ -284,6 +315,20 @@ pub fn record_batch_to_duckdb_data_chunk(
Ok(())
}

/// Copies every 16-byte entry of a UUID `FixedSizeBinaryArray` into `out_vector`
/// as an `i128`, then transfers the array's null information to the vector.
///
/// Each entry is read as a big-endian 128-bit integer. Panics if `out_vector`
/// exposes fewer `i128` slots than the array has entries.
fn uuid_array_to_vector(array: &FixedSizeBinaryArray, out_vector: &mut FlatVector) {
    // DuckDB keeps UUIDs in a signed 128-bit integer, so the top (sign) bit has to be
    // flipped for the stored value to keep fidelity with the unsigned big-endian bytes.
    const SIGN_BIT: i128 = i128::MIN;

    let slots = out_vector.as_mut_slice::<i128>();
    let raw_uuids = array.values().chunks_exact(UUID_LENGTH);
    for (row, bytes) in raw_uuids.enumerate() {
        // `chunks_exact` only yields slices of exactly UUID_LENGTH bytes.
        let bytes: [u8; UUID_LENGTH] = bytes.try_into().unwrap();
        slots[row] = i128::from_be_bytes(bytes) ^ SIGN_BIT;
    }

    set_nulls_in_flat_vector(array, out_vector);
}

fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<T>, out_vector: &mut FlatVector) {
// assert!(array.len() <= out_vector.capacity());
out_vector.copy::<T::Native>(array.values());
Expand Down Expand Up @@ -698,24 +743,29 @@ fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) {

#[cfg(test)]
mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use super::{arrow_recordbatch_to_query_params, ArrowVTab, UUID_LENGTH};
use crate::{
vtab::arrow::{EXTENSION_NAME_KEY, UUID_EXTENSION_NAME},
Connection, Result,
};
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array,
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray,
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, GenericListArray,
Int32Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer},
datatypes::{
i256, ArrowPrimitiveType, ByteArrayType, DataType, DurationSecondType, Field, Fields, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalYearMonthType, Schema,
},
record_batch::RecordBatch,
record_batch::RecordBatch, util::pretty,
};
use std::{error::Error, sync::Arc};
use std::{collections::HashMap, error::Error, sync::Arc};
use uuid::Uuid;

#[test]
fn test_vtab_arrow() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -1264,4 +1314,62 @@ mod test {
assert_eq!(column.len(), 1);
assert_eq!(column.value(0), b"test");
}

// Sends a nullable FixedSizeBinary(16) column tagged with the `arrow.uuid` extension name
// through the `arrow` table function and checks that DuckDB reports the column as UUID
// and prints the same values back.
#[test]
fn test_uuid_roundtrip() -> Result<(), Box<dyn Error>> {
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;
// Single nullable column "a"; the field metadata carries the extension name that marks
// the fixed-size binary values as UUIDs.
let schema = Schema::new(vec![Field::new(
"a",
DataType::FixedSizeBinary(UUID_LENGTH as i32),
true,
)
.with_metadata(HashMap::from([(
EXTENSION_NAME_KEY.to_string(),
UUID_EXTENSION_NAME.to_string(),
)]))]);

// Values with the most significant bit set (0xa1.., 0xb1..) and clear (0, 0x41, 0x42);
// presumably chosen to exercise the sign-bit handling of the conversion.
let uuids = vec![
Uuid::from_u128(0xa1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128),
Uuid::from_u128(0xb1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128),
Uuid::from_u128(0),
Uuid::from_u128(0x41),
Uuid::from_u128(0x42),
];
// Concatenate the 16-byte representations into one contiguous byte buffer.
let buf = uuids
.iter()
.flat_map(|uuid| uuid.as_bytes())
.copied()
.collect::<Vec<u8>>();
let buf = Buffer::from_vec(buf);
// The validity buffer marks the third entry (the all-zero UUID) as null.
let rb = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(FixedSizeBinaryArray::new(
UUID_LENGTH as i32,
buf,
Some(NullBuffer::from(vec![true, true, false, true, true])),
))],
)?;
// NOTE(review): the query has two `?` placeholders, so `param` presumably carries two
// values describing the record batch — confirm against `arrow_recordbatch_to_query_params`.
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("SELECT a, typeof(a) FROM arrow(?, ?)")?;
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
let rb = [rb];
// Compare against the pretty-printed table; the null row renders as an empty cell.
let printed = pretty::pretty_format_batches(&rb).unwrap();
assert_eq!(
"\
+--------------------------------------+-----------+
| a | typeof(a) |
+--------------------------------------+-----------+
| a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID |
| b1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID |
| | UUID |
| 00000000-0000-0000-0000-000000000041 | UUID |
| 00000000-0000-0000-0000-000000000042 | UUID |
+--------------------------------------+-----------+",
printed.to_string(),
"{printed}"
);
Ok(())
}
}

0 comments on commit b96bf56

Please sign in to comment.