|
17 | 17 |
|
18 | 18 | //! [`zip`]: Combine values from two arrays based on a boolean mask |
19 | 19 |
|
20 | | -use crate::filter::SlicesIterator; |
| 20 | +use crate::filter::{SlicesIterator, prep_null_mask_filter}; |
21 | 21 | use arrow_array::*; |
| 22 | +use arrow_buffer::BooleanBuffer; |
22 | 23 | use arrow_data::transform::MutableArrayData; |
23 | 24 | use arrow_schema::ArrowError; |
24 | 25 |
|
@@ -127,7 +128,8 @@ pub fn zip( |
127 | 128 | // keep track of how much is filled |
128 | 129 | let mut filled = 0; |
129 | 130 |
|
130 | | - SlicesIterator::new(mask).for_each(|(start, end)| { |
| 131 | + let mask = maybe_prep_null_mask_filter(mask); |
| 132 | + SlicesIterator::from(&mask).for_each(|(start, end)| { |
131 | 133 | // the gap needs to be filled with falsy values |
132 | 134 | if start > filled { |
133 | 135 | if falsy_is_scalar { |
@@ -166,9 +168,22 @@ pub fn zip( |
166 | 168 | Ok(make_array(data)) |
167 | 169 | } |
168 | 170 |
|
| 171 | +/// Converts the mask into plain selection bits; nulls are treated as false. |
| 172 | +fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { |
| 173 | +    if predicate.null_count() > 0 { |
| 174 | +        // keep only the value buffer of the prepared mask, the second part is discarded |
| 175 | +        let (selection, _nulls) = prep_null_mask_filter(predicate).into_parts(); |
| 176 | +        return selection; |
| 177 | +    } |
| 178 | +    // no nulls present: the mask's own values are used as they are |
| 179 | +    predicate.values().clone() |
| 180 | +} |
| 181 | + |
169 | 182 | #[cfg(test)] |
170 | 183 | mod test { |
171 | 184 | use super::*; |
| 185 | + use arrow_array::cast::AsArray; |
| 186 | + use arrow_buffer::{BooleanBuffer, NullBuffer}; |
172 | 187 |
|
173 | 188 | #[test] |
174 | 189 | fn test_zip_kernel_one() { |
@@ -279,4 +294,110 @@ mod test { |
279 | 294 | let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]); |
280 | 295 | assert_eq!(actual, &expected); |
281 | 296 | } |
| 297 | + |
| 298 | +    #[test] |
| 299 | +    fn test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() { |
| 300 | +        // Slot 3 has its value bit set to `true` but is marked null; a null mask |
| 301 | +        // entry is treated as false, so the falsy value is expected at that slot. |
| 302 | +        let value_bits = BooleanBuffer::from(vec![true, true, false, true, false, false]); |
| 303 | +        let validity = NullBuffer::from(vec![true, true, true, false, true, true]); |
| 304 | +        let mask = BooleanArray::new(value_bits, Some(validity)); |
| 305 | + |
| 306 | +        let if_true = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]); |
| 307 | +        let if_false = Int32Array::from_iter_values(vec![7, 8, 9, 10, 11, 12]); |
| 308 | + |
| 309 | +        let result = zip(&mask, &if_true, &if_false).unwrap(); |
| 310 | +        let got = result |
| 311 | +            .as_any() |
| 312 | +            .downcast_ref::<Int32Array>() |
| 313 | +            .unwrap(); |
| 314 | +        let want = Int32Array::from(vec![ |
| 315 | +            Some(1), |
| 316 | +            Some(2), |
| 317 | +            Some(9), |
| 318 | +            Some(10), // value bit is true, but the slot is null |
| 319 | +            Some(11), |
| 320 | +            Some(12), |
| 321 | +        ]); |
| 322 | +        assert_eq!(got, &want); |
| 323 | +    } |
| 324 | + |
| 325 | +    #[test] |
| 326 | +    fn test_zip_kernel_primitive_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() |
| 327 | +    { |
| 328 | +        // Slot 3 has its value bit set to `true` but is marked null; a null mask |
| 329 | +        // entry is treated as false, so the falsy scalar is expected at that slot. |
| 330 | +        let value_bits = BooleanBuffer::from(vec![true, true, false, true, false, false]); |
| 331 | +        let validity = NullBuffer::from(vec![true, true, true, false, true, true]); |
| 332 | +        let mask = BooleanArray::new(value_bits, Some(validity)); |
| 333 | + |
| 334 | +        let if_true = Scalar::new(Int32Array::from_value(42, 1)); |
| 335 | +        let if_false = Scalar::new(Int32Array::from_value(123, 1)); |
| 336 | + |
| 337 | +        let result = zip(&mask, &if_true, &if_false).unwrap(); |
| 338 | +        let got = result |
| 339 | +            .as_any() |
| 340 | +            .downcast_ref::<Int32Array>() |
| 341 | +            .unwrap(); |
| 342 | +        let want = Int32Array::from(vec![ |
| 343 | +            Some(42), |
| 344 | +            Some(42), |
| 345 | +            Some(123), |
| 346 | +            Some(123), // value bit is true, but the slot is null |
| 347 | +            Some(123), |
| 348 | +            Some(123), |
| 349 | +        ]); |
| 350 | +        assert_eq!(got, &want); |
| 351 | +    } |
| 352 | + |
| 353 | +    #[test] |
| 354 | +    fn test_zip_string_array_with_nulls_is_mask_should_be_treated_as_false() { |
| 355 | +        // Slot 3 has its value bit set to `true` but is marked null; a null mask |
| 356 | +        // entry is treated as false, so the falsy string is expected at that slot. |
| 357 | +        let value_bits = BooleanBuffer::from(vec![true, true, false, true, false, false]); |
| 358 | +        let validity = NullBuffer::from(vec![true, true, true, false, true, true]); |
| 359 | +        let mask = BooleanArray::new(value_bits, Some(validity)); |
| 360 | + |
| 361 | +        let if_true = StringArray::from_iter_values(vec!["1", "2", "3", "4", "5", "6"]); |
| 362 | +        let if_false = StringArray::from_iter_values(vec!["7", "8", "9", "10", "11", "12"]); |
| 363 | + |
| 364 | +        let result = zip(&mask, &if_true, &if_false).unwrap(); |
| 365 | +        let got = result.as_string::<i32>(); |
| 366 | + |
| 367 | +        // truthy values for slots 0 and 1, falsy values for every other slot |
| 368 | +        let want = StringArray::from_iter_values(vec![ |
| 369 | +            "1", "2", "9", |
| 370 | +            "10", // value bit is true, but the slot is null |
| 371 | +            "11", "12", |
| 372 | +        ]); |
| 373 | +        assert_eq!(got, &want); |
| 374 | +    } |
| 375 | + |
| 376 | +    #[test] |
| 377 | +    fn test_zip_kernel_large_string_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false() |
| 378 | +    { |
| 379 | +        // Slot 3 has its value bit set to `true` but is marked null; a null mask |
| 380 | +        // entry is treated as false, so the falsy scalar is expected at that slot. |
| 381 | +        let value_bits = BooleanBuffer::from(vec![true, true, false, true, false, false]); |
| 382 | +        let validity = NullBuffer::from(vec![true, true, true, false, true, true]); |
| 383 | +        let mask = BooleanArray::new(value_bits, Some(validity)); |
| 384 | + |
| 385 | +        let if_true = Scalar::new(LargeStringArray::from_iter_values(["test"])); |
| 386 | +        let if_false = Scalar::new(LargeStringArray::from_iter_values(["something else"])); |
| 387 | + |
| 388 | +        let result = zip(&mask, &if_true, &if_false).unwrap(); |
| 389 | +        let got = result |
| 390 | +            .as_any() |
| 391 | +            .downcast_ref::<LargeStringArray>() |
| 392 | +            .unwrap(); |
| 393 | +        let want = LargeStringArray::from_iter(vec![ |
| 394 | +            Some("test"), |
| 395 | +            Some("test"), |
| 396 | +            Some("something else"), |
| 397 | +            Some("something else"), // value bit is true, but the slot is null |
| 398 | +            Some("something else"), |
| 399 | +            Some("something else"), |
| 400 | +        ]); |
| 401 | +        assert_eq!(got, &want); |
| 402 | +    } |
282 | 403 | } |
0 commit comments