|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | use std::any::Any; |
| 19 | +use std::collections::HashMap; |
19 | 20 | use std::sync::Arc; |
20 | 21 |
|
| 22 | +use arrow::array::as_string_array; |
21 | 23 | use arrow::compute::kernels::numeric::add; |
22 | | -use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; |
| 24 | +use arrow_array::{ |
| 25 | + ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, |
| 26 | +}; |
23 | 27 | use arrow_schema::{DataType, Field, Schema}; |
24 | 28 | use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; |
25 | 29 | use datafusion::prelude::*; |
@@ -476,6 +480,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { |
476 | 480 | Ok(()) |
477 | 481 | } |
478 | 482 |
|
| 483 | +/// Volatile UDF that should append a different value to each row |
| 484 | +#[derive(Debug)] |
| 485 | +struct AddIndexToStringVolatileScalarUDF { |
| 486 | + name: String, |
| 487 | + signature: Signature, |
| 488 | + return_type: DataType, |
| 489 | +} |
| 490 | + |
| 491 | +impl AddIndexToStringVolatileScalarUDF { |
| 492 | + fn new() -> Self { |
| 493 | + Self { |
| 494 | + name: "add_index_to_string".to_string(), |
| 495 | + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), |
| 496 | + return_type: DataType::Utf8, |
| 497 | + } |
| 498 | + } |
| 499 | +} |
| 500 | + |
| 501 | +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { |
| 502 | + fn as_any(&self) -> &dyn Any { |
| 503 | + self |
| 504 | + } |
| 505 | + |
| 506 | + fn name(&self) -> &str { |
| 507 | + &self.name |
| 508 | + } |
| 509 | + |
| 510 | + fn signature(&self) -> &Signature { |
| 511 | + &self.signature |
| 512 | + } |
| 513 | + |
| 514 | + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
| 515 | + Ok(self.return_type.clone()) |
| 516 | + } |
| 517 | + |
| 518 | + fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { |
| 519 | + not_impl_err!("index_with_offset function does not accept arguments") |
| 520 | + } |
| 521 | + |
| 522 | + fn invoke_batch( |
| 523 | + &self, |
| 524 | + args: &[ColumnarValue], |
| 525 | + number_rows: usize, |
| 526 | + ) -> Result<ColumnarValue> { |
| 527 | + let answer = match &args[0] { |
| 528 | + // When called with static arguments, the result is returned as an array. |
| 529 | + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { |
| 530 | + let mut answer = vec![]; |
| 531 | + for index in 1..=number_rows { |
| 532 | + // When calling a function with immutable arguments, the result is returned with ")". |
| 533 | + // Example: SELECT add_index_to_string('const_value') FROM table; |
| 534 | + answer.push(index.to_string() + ") " + value); |
| 535 | + } |
| 536 | + answer |
| 537 | + } |
| 538 | + // The result is returned as an array when called with dynamic arguments. |
| 539 | + ColumnarValue::Array(array) => { |
| 540 | + let string_array = as_string_array(array); |
| 541 | + let mut counter = HashMap::<&str, u64>::new(); |
| 542 | + string_array |
| 543 | + .iter() |
| 544 | + .map(|value| { |
| 545 | + let value = value.expect("Unexpected null"); |
| 546 | + let index = counter.get(value).unwrap_or(&0) + 1; |
| 547 | + counter.insert(value, index); |
| 548 | + |
| 549 | + // When calling a function with mutable arguments, the result is returned with ".". |
| 550 | + // Example: SELECT add_index_to_string(table.value) FROM table; |
| 551 | + index.to_string() + ". " + value |
| 552 | + }) |
| 553 | + .collect() |
| 554 | + } |
| 555 | + _ => unimplemented!(), |
| 556 | + }; |
| 557 | + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) |
| 558 | + } |
| 559 | +} |
| 560 | + |
| 561 | +#[tokio::test] |
| 562 | +async fn volatile_scalar_udf_with_params() -> Result<()> { |
| 563 | + { |
| 564 | + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); |
| 565 | + |
| 566 | + let batch = RecordBatch::try_new( |
| 567 | + Arc::new(schema.clone()), |
| 568 | + vec![Arc::new(StringArray::from(vec![ |
| 569 | + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", |
| 570 | + ]))], |
| 571 | + )?; |
| 572 | + let ctx = SessionContext::new(); |
| 573 | + |
| 574 | + ctx.register_batch("t", batch)?; |
| 575 | + |
| 576 | + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); |
| 577 | + |
| 578 | + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); |
| 579 | + |
| 580 | + let result = |
| 581 | + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters |
| 582 | + .await?; |
| 583 | + let expected = [ |
| 584 | + "+-----------+", |
| 585 | + "| str |", |
| 586 | + "+-----------+", |
| 587 | + "| 1. test_1 |", |
| 588 | + "| 2. test_1 |", |
| 589 | + "| 3. test_1 |", |
| 590 | + "| 1. test_2 |", |
| 591 | + "| 2. test_2 |", |
| 592 | + "| 4. test_1 |", |
| 593 | + "| 3. test_2 |", |
| 594 | + "+-----------+", |
| 595 | + ]; |
| 596 | + assert_batches_eq!(expected, &result); |
| 597 | + |
| 598 | + let result = |
| 599 | + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters |
| 600 | + .await?; |
| 601 | + let expected = [ |
| 602 | + "+---------+", |
| 603 | + "| str |", |
| 604 | + "+---------+", |
| 605 | + "| 1) test |", |
| 606 | + "| 2) test |", |
| 607 | + "| 3) test |", |
| 608 | + "| 4) test |", |
| 609 | + "| 5) test |", |
| 610 | + "| 6) test |", |
| 611 | + "| 7) test |", |
| 612 | + "+---------+", |
| 613 | + ]; |
| 614 | + assert_batches_eq!(expected, &result); |
| 615 | + |
| 616 | + let result = |
| 617 | + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters |
| 618 | + .await?; |
| 619 | + let expected = [ |
| 620 | + "+---------------+", |
| 621 | + "| str |", |
| 622 | + "+---------------+", |
| 623 | + "| 1) test_value |", |
| 624 | + "+---------------+", |
| 625 | + ]; |
| 626 | + assert_batches_eq!(expected, &result); |
| 627 | + } |
| 628 | + { |
| 629 | + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); |
| 630 | + |
| 631 | + let batch = RecordBatch::try_new( |
| 632 | + Arc::new(schema.clone()), |
| 633 | + vec![Arc::new(StringArray::from(vec![ |
| 634 | + "test_1", "test_1", "test_1", |
| 635 | + ]))], |
| 636 | + )?; |
| 637 | + let ctx = SessionContext::new(); |
| 638 | + |
| 639 | + ctx.register_batch("t", batch)?; |
| 640 | + |
| 641 | + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); |
| 642 | + |
| 643 | + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); |
| 644 | + |
| 645 | + let result = |
| 646 | + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") |
| 647 | + .await?; |
| 648 | + let expected = [ |
| 649 | + "+-----------+", // |
| 650 | + "| str |", // |
| 651 | + "+-----------+", // |
| 652 | + "| 1. test_1 |", // |
| 653 | + "| 2. test_1 |", // |
| 654 | + "| 3. test_1 |", // |
| 655 | + "+-----------+", |
| 656 | + ]; |
| 657 | + assert_batches_eq!(expected, &result); |
| 658 | + } |
| 659 | + Ok(()) |
| 660 | +} |
| 661 | + |
479 | 662 | #[derive(Debug)] |
480 | 663 | struct CastToI64UDF { |
481 | 664 | signature: Signature, |
|
0 commit comments