Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vtab::arrow: support UUID extension type #405

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/duckdb/src/appender/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Appender<'_> {
let schema = record_batch.schema();
let mut logical_type: Vec<LogicalTypeHandle> = vec![];
for field in schema.fields() {
let logical_t = to_duckdb_logical_type(field.data_type())
let logical_t = to_duckdb_logical_type(field.data_type(), field.metadata())
.map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))?;
logical_type.push(logical_t);
}
Expand Down
11 changes: 10 additions & 1 deletion crates/duckdb/src/core/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use crate::ffi::{
duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk,
duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, duckdb_vector_get_column_type,
};

/// Handle to the DataChunk in DuckDB.
Expand Down Expand Up @@ -59,6 +59,15 @@ impl DataChunkHandle {
StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
}

/// Get the logical type of the vector at the column index: `idx`.
pub fn logical_type(&self, idx: usize) -> LogicalTypeHandle {
unsafe {
LogicalTypeHandle::new(duckdb_vector_get_column_type(duckdb_data_chunk_get_vector(
self.ptr, idx as u64,
)))
}
}

/// Set the size of the data chunk
pub fn set_len(&self, new_len: usize) {
unsafe { duckdb_data_chunk_set_size(self.ptr, new_len as u64) };
Expand Down
156 changes: 132 additions & 24 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab};
use std::ptr::null_mut;
use std::{collections::HashMap, ptr::null_mut};

use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector};
use arrow::{
Expand Down Expand Up @@ -88,7 +88,7 @@ impl VTab for ArrowVTab {
for f in rb.schema().fields() {
let name = f.name();
let data_type = f.data_type();
let logical_type = to_duckdb_logical_type(data_type)?;
let logical_type = to_duckdb_logical_type(data_type, f.metadata())?;
bind.add_result_column(name, logical_type);
}
(*data).rb = Box::into_raw(Box::new(rb));
Expand Down Expand Up @@ -128,8 +128,15 @@ impl VTab for ArrowVTab {
}
}

const EXTENSION_NAME_KEY: &str = "ARROW:extension:name";
const UUID_EXTENSION_NAME: &str = "arrow.uuid";
const UUID_LENGTH: usize = 16;

/// Convert arrow DataType to duckdb type id
pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn std::error::Error>> {
pub fn to_duckdb_type_id(
data_type: &DataType,
metadata: &HashMap<String, String>,
) -> Result<LogicalTypeId, Box<dyn std::error::Error>> {
use LogicalTypeId::*;

let type_id = match data_type {
Expand Down Expand Up @@ -157,7 +164,17 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
DataType::Interval(_) => Interval,
DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => Blob,
DataType::FixedSizeBinary(_) => {
if metadata
.get(EXTENSION_NAME_KEY)
.map_or(false, |name| name == UUID_EXTENSION_NAME)
{
Uuid
} else {
Blob
}
}
DataType::Binary | DataType::LargeBinary => Blob,
DataType::Utf8 | DataType::LargeUtf8 => Varchar,
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List,
DataType::Struct(_) => Struct,
Expand All @@ -177,21 +194,28 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
}

/// Convert arrow DataType to duckdb logical type
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
pub fn to_duckdb_logical_type(
data_type: &DataType,
metadata: &HashMap<String, String>,
) -> Result<LogicalTypeHandle, Box<dyn std::error::Error>> {
match data_type {
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type, &HashMap::new()),
DataType::Struct(fields) => {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
shape.push((
field.name().as_str(),
to_duckdb_logical_type(field.data_type(), field.metadata())?,
));
}
Ok(LogicalTypeHandle::struct_type(shape.as_slice()))
}
DataType::List(child) | DataType::LargeList(child) => {
Ok(LogicalTypeHandle::list(&to_duckdb_logical_type(child.data_type())?))
}
DataType::List(child) | DataType::LargeList(child) => Ok(LogicalTypeHandle::list(&to_duckdb_logical_type(
child.data_type(),
child.metadata(),
)?)),
DataType::FixedSizeList(child, array_size) => Ok(LogicalTypeHandle::array(
&to_duckdb_logical_type(child.data_type())?,
&to_duckdb_logical_type(child.data_type(), child.metadata())?,
*array_size as u64,
)),
DataType::Decimal128(width, scale) if *scale > 0 => {
Expand All @@ -203,8 +227,8 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle,
| DataType::LargeUtf8
| DataType::Binary
| DataType::LargeBinary
| DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
| DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)),
dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type, metadata)?)),
_ => Err(format!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
)
Expand Down Expand Up @@ -245,8 +269,15 @@ pub fn record_batch_to_duckdb_data_chunk(
DataType::Binary => {
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
}
DataType::FixedSizeBinary(_) => {
fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
DataType::FixedSizeBinary(length) => {
if chunk.logical_type(i).id() == LogicalTypeId::Uuid {
if *length != UUID_LENGTH as i32 {
return Err(format!("UUID FixedSizeBinaryArray must have value length of 16").into());
}
uuid_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
} else {
fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i));
}
}
DataType::LargeBinary => {
large_binary_array_to_vector(
Expand Down Expand Up @@ -284,6 +315,20 @@ pub fn record_batch_to_duckdb_data_chunk(
Ok(())
}

fn uuid_array_to_vector(array: &FixedSizeBinaryArray, out_vector: &mut FlatVector) {
let out_data: &mut [i128] = out_vector.as_mut_slice();
for (i, value) in array.values().chunks_exact(UUID_LENGTH).enumerate() {
let value: [u8; UUID_LENGTH] = value.try_into().unwrap();
let value = i128::from_be_bytes(value);
// For whatever reason, DuckDB internally uses a signed integer to represent UUIDs
// but we need to swap the sign bit in order to maintain fidelity.
const MASK: i128 = (1u128 << 127) as _;
let value = value ^ MASK;
out_data[i] = value;
}
set_nulls_in_flat_vector(array, out_vector);
}

fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<T>, out_vector: &mut FlatVector) {
// assert!(array.len() <= out_vector.capacity());
out_vector.copy::<T::Native>(array.values());
Expand Down Expand Up @@ -698,24 +743,29 @@ fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) {

#[cfg(test)]
mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use super::{arrow_recordbatch_to_query_params, ArrowVTab, UUID_LENGTH};
use crate::{
vtab::arrow::{EXTENSION_NAME_KEY, UUID_EXTENSION_NAME},
Connection, Result,
};
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array,
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray,
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, GenericListArray,
Int32Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
LargeStringArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer},
datatypes::{
i256, ArrowPrimitiveType, ByteArrayType, DataType, DurationSecondType, Field, Fields, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalYearMonthType, Schema,
},
record_batch::RecordBatch,
record_batch::RecordBatch, util::pretty,
};
use std::{error::Error, sync::Arc};
use std::{collections::HashMap, error::Error, sync::Arc};
use uuid::Uuid;

#[test]
fn test_vtab_arrow() -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -1264,4 +1314,62 @@ mod test {
assert_eq!(column.len(), 1);
assert_eq!(column.value(0), b"test");
}

#[test]
fn test_uuid_roundtrip() -> Result<(), Box<dyn Error>> {
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;
let schema = Schema::new(vec![Field::new(
"a",
DataType::FixedSizeBinary(UUID_LENGTH as i32),
true,
)
.with_metadata(HashMap::from([(
EXTENSION_NAME_KEY.to_string(),
UUID_EXTENSION_NAME.to_string(),
)]))]);

let uuids = vec![
Uuid::from_u128(0xa1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128),
Uuid::from_u128(0xb1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8u128),
Uuid::from_u128(0),
Uuid::from_u128(0x41),
Uuid::from_u128(0x42),
];
let buf = uuids
.iter()
.flat_map(|uuid| uuid.as_bytes())
.copied()
.collect::<Vec<u8>>();
let buf = Buffer::from_vec(buf);
let rb = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(FixedSizeBinaryArray::new(
UUID_LENGTH as i32,
buf,
Some(NullBuffer::from(vec![true, true, false, true, true])),
))],
)?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("SELECT a, typeof(a) FROM arrow(?, ?)")?;
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
let rb = [rb];
let printed = pretty::pretty_format_batches(&rb).unwrap();
assert_eq!(
"\
+--------------------------------------+-----------+
| a | typeof(a) |
+--------------------------------------+-----------+
| a1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID |
| b1a2a3a4-b1b2-c1c2-d1d2-d3d4d5d6d7d8 | UUID |
| | UUID |
| 00000000-0000-0000-0000-000000000041 | UUID |
| 00000000-0000-0000-0000-000000000042 | UUID |
+--------------------------------------+-----------+",
printed.to_string(),
"{printed}"
);
Ok(())
}
}
Loading