|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | use std::any::Any; |
| 19 | +use std::collections::HashMap; |
19 | 20 | use std::hash::{DefaultHasher, Hash, Hasher}; |
20 | 21 | use std::sync::Arc; |
21 | 22 |
|
| 23 | +use arrow::array::as_string_array; |
22 | 24 | use arrow::compute::kernels::numeric::add; |
23 | 25 | use arrow_array::builder::BooleanBuilder; |
24 | 26 | use arrow_array::cast::AsArray; |
@@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { |
483 | 485 | Ok(()) |
484 | 486 | } |
485 | 487 |
|
/// Volatile test UDF that prefixes every row's string with a 1-based index,
/// so each row of the output differs even when the inputs are equal.
///
/// For a literal argument the index is the row number (`"1) value"`); for a
/// column argument it is the running occurrence count of that value within
/// the batch (`"1. value"`).
#[derive(Debug)]
struct AddIndexToStringVolatileScalarUDF {
    /// Name the function is exposed under.
    name: String,
    /// Accepted argument types together with the volatility of the function.
    signature: Signature,
    /// Data type reported by `return_type`.
    return_type: DataType,
}
| 495 | + |
| 496 | +impl AddIndexToStringVolatileScalarUDF { |
| 497 | + fn new() -> Self { |
| 498 | + Self { |
| 499 | + name: "add_index_to_string".to_string(), |
| 500 | + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), |
| 501 | + return_type: DataType::Utf8, |
| 502 | + } |
| 503 | + } |
| 504 | +} |
| 505 | + |
| 506 | +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { |
| 507 | + fn as_any(&self) -> &dyn Any { |
| 508 | + self |
| 509 | + } |
| 510 | + |
| 511 | + fn name(&self) -> &str { |
| 512 | + &self.name |
| 513 | + } |
| 514 | + |
| 515 | + fn signature(&self) -> &Signature { |
| 516 | + &self.signature |
| 517 | + } |
| 518 | + |
| 519 | + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
| 520 | + Ok(self.return_type.clone()) |
| 521 | + } |
| 522 | + |
| 523 | + fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { |
| 524 | + not_impl_err!("index_with_offset function does not accept arguments") |
| 525 | + } |
| 526 | + |
| 527 | + fn invoke_batch( |
| 528 | + &self, |
| 529 | + args: &[ColumnarValue], |
| 530 | + number_rows: usize, |
| 531 | + ) -> Result<ColumnarValue> { |
| 532 | + let answer = match &args[0] { |
| 533 | + // When called with static arguments, the result is returned as an array. |
| 534 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { |
| 535 | + let mut answer = vec![]; |
| 536 | + for index in 1..=number_rows { |
| 537 | + // When calling a function with immutable arguments, the result is returned with ")". |
| 538 | + // Example: SELECT add_index_to_string('const_value') FROM table; |
| 539 | + answer.push(index.to_string() + ") " + value); |
| 540 | + } |
| 541 | + answer |
| 542 | + } |
| 543 | + // The result is returned as an array when called with dynamic arguments. |
| 544 | + ColumnarValue::Array(array) => { |
| 545 | + let string_array = as_string_array(array); |
| 546 | + let mut counter = HashMap::<&str, u64>::new(); |
| 547 | + string_array |
| 548 | + .iter() |
| 549 | + .map(|value| { |
| 550 | + let value = value.expect("Unexpected null"); |
| 551 | + let index = counter.get(value).unwrap_or(&0) + 1; |
| 552 | + counter.insert(value, index); |
| 553 | + |
| 554 | + // When calling a function with mutable arguments, the result is returned with ".". |
| 555 | + // Example: SELECT add_index_to_string(table.value) FROM table; |
| 556 | + index.to_string() + ". " + value |
| 557 | + }) |
| 558 | + .collect() |
| 559 | + } |
| 560 | + _ => unimplemented!(), |
| 561 | + }; |
| 562 | + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) |
| 563 | + } |
| 564 | +} |
| 565 | + |
| 566 | +#[tokio::test] |
| 567 | +async fn volatile_scalar_udf_with_params() -> Result<()> { |
| 568 | + { |
| 569 | + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); |
| 570 | + |
| 571 | + let batch = RecordBatch::try_new( |
| 572 | + Arc::new(schema.clone()), |
| 573 | + vec![Arc::new(StringArray::from(vec![ |
| 574 | + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", |
| 575 | + ]))], |
| 576 | + )?; |
| 577 | + let ctx = SessionContext::new(); |
| 578 | + |
| 579 | + ctx.register_batch("t", batch)?; |
| 580 | + |
| 581 | + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); |
| 582 | + |
| 583 | + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); |
| 584 | + |
| 585 | + let result = |
| 586 | + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters |
| 587 | + .await?; |
| 588 | + let expected = [ |
| 589 | + "+-----------+", |
| 590 | + "| str |", |
| 591 | + "+-----------+", |
| 592 | + "| 1. test_1 |", |
| 593 | + "| 2. test_1 |", |
| 594 | + "| 3. test_1 |", |
| 595 | + "| 1. test_2 |", |
| 596 | + "| 2. test_2 |", |
| 597 | + "| 4. test_1 |", |
| 598 | + "| 3. test_2 |", |
| 599 | + "+-----------+", |
| 600 | + ]; |
| 601 | + assert_batches_eq!(expected, &result); |
| 602 | + |
| 603 | + let result = |
| 604 | + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters |
| 605 | + .await?; |
| 606 | + let expected = [ |
| 607 | + "+---------+", |
| 608 | + "| str |", |
| 609 | + "+---------+", |
| 610 | + "| 1) test |", |
| 611 | + "| 2) test |", |
| 612 | + "| 3) test |", |
| 613 | + "| 4) test |", |
| 614 | + "| 5) test |", |
| 615 | + "| 6) test |", |
| 616 | + "| 7) test |", |
| 617 | + "+---------+", |
| 618 | + ]; |
| 619 | + assert_batches_eq!(expected, &result); |
| 620 | + |
| 621 | + let result = |
| 622 | + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters |
| 623 | + .await?; |
| 624 | + let expected = [ |
| 625 | + "+---------------+", |
| 626 | + "| str |", |
| 627 | + "+---------------+", |
| 628 | + "| 1) test_value |", |
| 629 | + "+---------------+", |
| 630 | + ]; |
| 631 | + assert_batches_eq!(expected, &result); |
| 632 | + } |
| 633 | + { |
| 634 | + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); |
| 635 | + |
| 636 | + let batch = RecordBatch::try_new( |
| 637 | + Arc::new(schema.clone()), |
| 638 | + vec![Arc::new(StringArray::from(vec![ |
| 639 | + "test_1", "test_1", "test_1", |
| 640 | + ]))], |
| 641 | + )?; |
| 642 | + let ctx = SessionContext::new(); |
| 643 | + |
| 644 | + ctx.register_batch("t", batch)?; |
| 645 | + |
| 646 | + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); |
| 647 | + |
| 648 | + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); |
| 649 | + |
| 650 | + let result = |
| 651 | + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") |
| 652 | + .await?; |
| 653 | + let expected = [ |
| 654 | + "+-----------+", // |
| 655 | + "| str |", // |
| 656 | + "+-----------+", // |
| 657 | + "| 1. test_1 |", // |
| 658 | + "| 2. test_1 |", // |
| 659 | + "| 3. test_1 |", // |
| 660 | + "+-----------+", |
| 661 | + ]; |
| 662 | + assert_batches_eq!(expected, &result); |
| 663 | + } |
| 664 | + Ok(()) |
| 665 | +} |
| 666 | + |
486 | 667 | #[derive(Debug)] |
487 | 668 | struct CastToI64UDF { |
488 | 669 | signature: Signature, |
|
0 commit comments