Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

respect offset in utf8 and list casts #335

Merged
merged 1 commit into from
May 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,7 @@ where
};

let mut builder = ArrayData::builder(dtype)
.offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_buffer(str_values_buf);
Expand Down Expand Up @@ -1744,7 +1745,12 @@ where
_ => unreachable!(),
};

let offsets = data.buffer::<OffsetSizeFrom>(0);
// Safety:
// The first buffer is the offsets and they are aligned to OffSetSizeFrom: (i64 or i32)
// Justification:
// The safe variant data.buffer::<OffsetSizeFrom> take the offset into account and we
// cannot create a list array with offsets starting at non zero.
let offsets = unsafe { data.buffers()[0].as_slice().align_to::<OffsetSizeFrom>() }.1;

let iter = offsets.iter().map(|idx| {
let idx: OffsetSizeTo = NumCast::from(*idx).unwrap();
Expand All @@ -1757,6 +1763,7 @@ where

// wrap up
let mut builder = ArrayData::builder(out_dtype)
.offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_child_data(value_data);
Expand Down Expand Up @@ -3840,4 +3847,30 @@ mod tests {
Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
]
}

#[test]
fn test_utf8_cast_offsets() {
// test if offset of the array is taken into account during cast
let str_array = StringArray::from(vec!["a", "b", "c"]);
let str_array = str_array.slice(1, 2);

let out = cast(&str_array, &DataType::LargeUtf8).unwrap();

let large_str_array = out.as_any().downcast_ref::<LargeStringArray>().unwrap();
let strs = large_str_array.into_iter().flatten().collect::<Vec<_>>();
assert_eq!(strs, &["b", "c"])
}

#[test]
fn test_list_cast_offsets() {
// test if offset of the array is taken into account during cast
let array1 = make_list_array().slice(1, 2);
let array2 = Arc::new(make_list_array()) as ArrayRef;

let dt = DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true)));
let out1 = cast(&array1, &dt).unwrap();
let out2 = cast(&array2, &dt).unwrap();

assert_eq!(&out1, &out2.slice(1, 2))
}
}