Skip to content

Commit ca4b62d

Browse files
agscppAgaev Huseyn
authored andcommitted
Fix functions with Volatility::Volatile and parameters (apache#13001)
Co-authored-by: Agaev Huseyn <h.agaev@vkteam.ru>
1 parent d53f727 commit ca4b62d

File tree

3 files changed

+212
-5
lines changed

3 files changed

+212
-5
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::collections::HashMap;
1920
use std::hash::{DefaultHasher, Hash, Hasher};
2021
use std::sync::Arc;
2122

23+
use arrow::array::as_string_array;
2224
use arrow::compute::kernels::numeric::add;
2325
use arrow_array::builder::BooleanBuilder;
2426
use arrow_array::cast::AsArray;
@@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
483485
Ok(())
484486
}
485487

488+
/// Volatile UDF that should append a different value to each row
489+
#[derive(Debug)]
490+
struct AddIndexToStringVolatileScalarUDF {
491+
name: String,
492+
signature: Signature,
493+
return_type: DataType,
494+
}
495+
496+
impl AddIndexToStringVolatileScalarUDF {
497+
fn new() -> Self {
498+
Self {
499+
name: "add_index_to_string".to_string(),
500+
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
501+
return_type: DataType::Utf8,
502+
}
503+
}
504+
}
505+
506+
impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
507+
fn as_any(&self) -> &dyn Any {
508+
self
509+
}
510+
511+
fn name(&self) -> &str {
512+
&self.name
513+
}
514+
515+
fn signature(&self) -> &Signature {
516+
&self.signature
517+
}
518+
519+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
520+
Ok(self.return_type.clone())
521+
}
522+
523+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
524+
not_impl_err!("index_with_offset function does not accept arguments")
525+
}
526+
527+
fn invoke_batch(
528+
&self,
529+
args: &[ColumnarValue],
530+
number_rows: usize,
531+
) -> Result<ColumnarValue> {
532+
let answer = match &args[0] {
533+
// When called with static arguments, the result is returned as an array.
534+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
535+
let mut answer = vec![];
536+
for index in 1..=number_rows {
537+
// When calling a function with immutable arguments, the result is returned with ")".
538+
// Example: SELECT add_index_to_string('const_value') FROM table;
539+
answer.push(index.to_string() + ") " + value);
540+
}
541+
answer
542+
}
543+
// The result is returned as an array when called with dynamic arguments.
544+
ColumnarValue::Array(array) => {
545+
let string_array = as_string_array(array);
546+
let mut counter = HashMap::<&str, u64>::new();
547+
string_array
548+
.iter()
549+
.map(|value| {
550+
let value = value.expect("Unexpected null");
551+
let index = counter.get(value).unwrap_or(&0) + 1;
552+
counter.insert(value, index);
553+
554+
// When calling a function with mutable arguments, the result is returned with ".".
555+
// Example: SELECT add_index_to_string(table.value) FROM table;
556+
index.to_string() + ". " + value
557+
})
558+
.collect()
559+
}
560+
_ => unimplemented!(),
561+
};
562+
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
563+
}
564+
}
565+
566+
#[tokio::test]
567+
async fn volatile_scalar_udf_with_params() -> Result<()> {
568+
{
569+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
570+
571+
let batch = RecordBatch::try_new(
572+
Arc::new(schema.clone()),
573+
vec![Arc::new(StringArray::from(vec![
574+
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2",
575+
]))],
576+
)?;
577+
let ctx = SessionContext::new();
578+
579+
ctx.register_batch("t", batch)?;
580+
581+
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();
582+
583+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
584+
585+
let result =
586+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters
587+
.await?;
588+
let expected = [
589+
"+-----------+",
590+
"| str |",
591+
"+-----------+",
592+
"| 1. test_1 |",
593+
"| 2. test_1 |",
594+
"| 3. test_1 |",
595+
"| 1. test_2 |",
596+
"| 2. test_2 |",
597+
"| 4. test_1 |",
598+
"| 3. test_2 |",
599+
"+-----------+",
600+
];
601+
assert_batches_eq!(expected, &result);
602+
603+
let result =
604+
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters
605+
.await?;
606+
let expected = [
607+
"+---------+",
608+
"| str |",
609+
"+---------+",
610+
"| 1) test |",
611+
"| 2) test |",
612+
"| 3) test |",
613+
"| 4) test |",
614+
"| 5) test |",
615+
"| 6) test |",
616+
"| 7) test |",
617+
"+---------+",
618+
];
619+
assert_batches_eq!(expected, &result);
620+
621+
let result =
622+
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters
623+
.await?;
624+
let expected = [
625+
"+---------------+",
626+
"| str |",
627+
"+---------------+",
628+
"| 1) test_value |",
629+
"+---------------+",
630+
];
631+
assert_batches_eq!(expected, &result);
632+
}
633+
{
634+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
635+
636+
let batch = RecordBatch::try_new(
637+
Arc::new(schema.clone()),
638+
vec![Arc::new(StringArray::from(vec![
639+
"test_1", "test_1", "test_1",
640+
]))],
641+
)?;
642+
let ctx = SessionContext::new();
643+
644+
ctx.register_batch("t", batch)?;
645+
646+
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();
647+
648+
ctx.register_udf(ScalarUDF::from(get_new_str_udf));
649+
650+
let result =
651+
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t")
652+
.await?;
653+
let expected = [
654+
"+-----------+", //
655+
"| str |", //
656+
"+-----------+", //
657+
"| 1. test_1 |", //
658+
"| 2. test_1 |", //
659+
"| 3. test_1 |", //
660+
"+-----------+",
661+
];
662+
assert_batches_eq!(expected, &result);
663+
}
664+
Ok(())
665+
}
666+
486667
#[derive(Debug)]
487668
struct CastToI64UDF {
488669
signature: Signature,

datafusion/expr/src/udf.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,17 @@ impl ScalarUDF {
209209
self.inner.is_nullable(args, schema)
210210
}
211211

212+
/// Invoke the function with `args` and number of rows, returning the appropriate result.
213+
///
214+
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
215+
pub fn invoke_batch(
216+
&self,
217+
args: &[ColumnarValue],
218+
number_rows: usize,
219+
) -> Result<ColumnarValue> {
220+
self.inner.invoke_batch(args, number_rows)
221+
}
222+
212223
/// Invoke the function without `args` but number of rows, returning the appropriate result.
213224
///
214225
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
@@ -446,7 +457,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
446457
/// to arrays, which will likely be simpler code, but be slower.
447458
///
448459
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
449-
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
460+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
461+
not_impl_err!(
462+
"Function {} does not implement invoke but called",
463+
self.name()
464+
)
465+
}
466+
467+
/// Invoke the function with `args` and the number of rows,
468+
/// returning the appropriate result.
469+
fn invoke_batch(
470+
&self,
471+
args: &[ColumnarValue],
472+
number_rows: usize,
473+
) -> Result<ColumnarValue> {
474+
match args.is_empty() {
475+
true => self.invoke_no_args(number_rows),
476+
false => self.invoke(args),
477+
}
478+
}
450479

451480
/// Invoke the function without `args`, instead the number of rows are provided,
452481
/// returning the appropriate result.

datafusion/physical-expr/src/scalar_function.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
140140
.collect::<Result<Vec<_>>>()?;
141141

142142
// evaluate the function
143-
let output = match self.args.is_empty() {
144-
true => self.fun.invoke_no_args(batch.num_rows()),
145-
false => self.fun.invoke(&inputs),
146-
}?;
143+
let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;
147144

148145
if let ColumnarValue::Array(array) = &output {
149146
if array.len() != batch.num_rows() {

0 commit comments

Comments
 (0)