Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 224 additions & 39 deletions crates/core-executor/src/duckdb/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,76 @@ use crate::error::{self as ex_error, Result as CoreResult};
use arrow_schema::{DataType, Field, FieldRef};
use datafusion::arrow::array::Array;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature};
use datafusion_common::config::ConfigOptions;
use datafusion_common::internal_datafusion_err;
use datafusion_expr::{ScalarUDF, ScalarUDFImpl};
use duckdb::Connection;
use duckdb::vscalar::{ArrowFunctionSignature, VArrowScalar};
use embucket_functions::string_binary::length::LengthFunc;
use embucket_functions::conditional::{
booland, boolor, boolxor, equal_null, iff, nullifzero, zeroifnull,
};
use embucket_functions::crypto::md5;
use embucket_functions::datetime::date_part_extract::Interval;
use embucket_functions::datetime::{
add_months, date_add, date_diff, date_from_parts, date_part_extract, dayname, last_day,
monthname, next_day, previous_day, time_from_parts, timestamp_from_parts,
};
use embucket_functions::numeric::div0;
use embucket_functions::regexp::{
regexp_instr, regexp_like, regexp_replace, regexp_substr, regexp_substr_all,
};
use embucket_functions::string_binary::{
hex_decode_binary, hex_decode_string, hex_encode, insert, jarowinkler_similarity as js, length,
lower, parse_ip, randstr, replace, rtrimmed_length, sha2, split, strtok, substr,
};
use embucket_functions::system::{cancel_query, typeof_func};
use snafu::ResultExt;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::BuildHasher;
use std::{error::Error, sync::Arc};
pub struct DfUdfWrapper<T: ScalarUDFImpl> {
_inner: T,
use strum::IntoEnumIterator;

/// Generic adapter that exposes a `DataFusion` scalar UDF to `DuckDB`.
///
/// `T` is carried purely at the type level through `PhantomData`; the struct
/// itself stores no value of type `T`.
#[derive(Debug, Clone)]
pub struct DfUdfWrapper<T: ScalarUDFImpl + Default> {
// Zero-sized marker that ties this wrapper to its `ScalarUDFImpl` type.
_marker: std::marker::PhantomData<T>,
}

/// Per-registration state holding the concrete `ScalarUDF` instance that
/// `invoke()` should call.
///
/// The UDF is kept behind an `Arc`, so cloning the state only bumps a
/// reference count. `Debug` is derived so the state can appear in diagnostics,
/// matching the derives on `DfUdfWrapper`.
#[derive(Debug, Clone)]
pub struct UdfState {
    /// Shared handle to the UDF this state was created for.
    udf: Arc<ScalarUDF>,
}

impl<T: ScalarUDFImpl> DfUdfWrapper<T> {
pub const fn new(inner: T) -> Self {
Self { _inner: inner }
impl UdfState {
    /// Builds a state object around an already constructed `ScalarUDF`.
    #[must_use]
    pub const fn new(udf: Arc<ScalarUDF>) -> Self {
        Self { udf }
    }

    /// Returns another shared handle to the wrapped UDF.
    ///
    /// Only the `Arc` reference count is bumped; the UDF itself is not copied.
    #[must_use]
    pub fn udf(&self) -> Arc<ScalarUDF> {
        Arc::clone(&self.udf)
    }
}

impl Default for UdfState {
    /// Produces a placeholder state backed by the `length` UDF.
    ///
    /// The placeholder is only a stand-in: it is expected to be replaced by the
    /// state supplied at registration time (`register_with_state`).
    fn default() -> Self {
        Self::new(length::get_udf())
    }
}

impl<T: ScalarUDFImpl + Default> VArrowScalar for DfUdfWrapper<T> {
type State = ();
type State = UdfState;

fn invoke(_state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn Error>> {
fn invoke(state: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn Error>> {
let num_rows = input.num_rows();
let schema = input.schema();
let func = T::default();
let func = state.udf();
let args: Vec<ColumnarValue> = input
.columns()
.iter()
Expand Down Expand Up @@ -68,38 +114,177 @@ impl<T: ScalarUDFImpl + Default> VArrowScalar for DfUdfWrapper<T> {
fn signatures() -> Vec<ArrowFunctionSignature> {
let func = T::default();
let sig = func.signature();
expand_signature(&func, &sig.type_signature)
}
}

match &sig.type_signature {
datafusion::logical_expr::TypeSignature::Exact(types) => {
vec![ArrowFunctionSignature::exact(
types.clone(),
func.return_type(types).unwrap_or(DataType::Utf8),
)]
}
datafusion::logical_expr::TypeSignature::Variadic(valid_types) => {
vec![ArrowFunctionSignature::exact(
vec![valid_types.first().cloned().unwrap_or(DataType::Utf8)],
func.return_type(&[valid_types.first().cloned().unwrap_or(DataType::Utf8)])
.unwrap_or(DataType::Utf8),
)]
}
datafusion::logical_expr::TypeSignature::Any(n) => {
let args = vec![DataType::Utf8; *n];
let ret = func.return_type(&args).unwrap_or(DataType::Utf8);
vec![ArrowFunctionSignature::exact(args, ret)]
}
_ => {
let ret = func
.return_type(&[DataType::Utf8])
.unwrap_or(DataType::Utf8);
vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)]
}
}
/// Registers the Embucket scalar UDFs listed in this function on a `DuckDB` connection.
///
/// Every function is looked up by name in `udfs` and registered through
/// `register_duckdb_udf`, `register_duckdb_udf_try` or
/// `register_duckdb_udf_internal`.
///
/// Returns the names of the functions whose `DuckDB` registration failed; an
/// empty vector means every registration succeeded.
///
/// # Errors
///
/// Returns an error as soon as one of the registration helpers returns one
/// (each call below is followed by `?`).
pub fn register_all_udfs<S>(
conn: &Connection,
udfs: &HashMap<String, Arc<ScalarUDF>, S>,
) -> CoreResult<Vec<String>>
where
S: BuildHasher,
{
// Names that could not be registered in DuckDB; filled in by the helpers.
let mut failed: Vec<String> = Vec::new();

// String binary
// The hex decode functions are registered under both their plain name and a `try_` name.
register_duckdb_udf::<hex_decode_binary::HexDecodeBinaryFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf_try::<hex_decode_binary::HexDecodeBinaryFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<hex_decode_string::HexDecodeStringFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf_try::<hex_decode_string::HexDecodeStringFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<hex_encode::HexEncodeFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<insert::Insert, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<js::JarowinklerSimilarityFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<length::LengthFunc, S>(conn, udfs, &mut failed)?;

register_duckdb_udf::<lower::LowerFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<parse_ip::ParseIpFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<randstr::RandStrFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<rtrimmed_length::RTrimmedLengthFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<sha2::Sha2Func, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<split::SplitFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<strtok::StrtokFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<substr::SubstrFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<replace::ReplaceFunc, S>(conn, udfs, &mut failed)?;

// Regexp
register_duckdb_udf::<regexp_instr::RegexpInstrFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<regexp_like::RegexpLikeFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<regexp_replace::RegexpReplaceFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<regexp_substr::RegexpSubstrFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<regexp_substr_all::RegexpSubstrAllFunc, S>(conn, udfs, &mut failed)?;

// Conditional
register_duckdb_udf::<booland::BoolAndFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<boolor::BoolOrFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<boolxor::BoolXorFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<equal_null::EqualNullFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<iff::IffFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<nullifzero::NullIfZeroFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<zeroifnull::ZeroIfNullFunc, S>(conn, udfs, &mut failed)?;

// Crypto
register_duckdb_udf::<md5::Md5Func, S>(conn, udfs, &mut failed)?;

// Datetime
register_duckdb_udf::<add_months::AddMonthsFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<date_add::DateAddFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<date_diff::DateDiffFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<date_from_parts::DateFromPartsFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<dayname::DayNameFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<last_day::LastDayFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<monthname::MonthNameFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<next_day::NextDayFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<previous_day::PreviousDayFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<time_from_parts::TimeFromPartsFunc, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<timestamp_from_parts::TimestampFromPartsFunc, S>(
conn,
udfs,
&mut failed,
)?;
// One registration per `Interval` variant, named by the variant's `Display` string.
for interval in Interval::iter() {
register_duckdb_udf_internal::<date_part_extract::DatePartExtractFunc, S>(
conn,
udfs,
&interval.to_string(),
&mut failed,
)?;
}

// Numeric
// `div0null` is registered under an explicit name, using `Div0Func` as the wrapper type.
register_duckdb_udf_internal::<div0::Div0Func, S>(conn, udfs, "div0null", &mut failed)?;
register_duckdb_udf::<div0::Div0Func, S>(conn, udfs, &mut failed)?;

// System
register_duckdb_udf::<cancel_query::SystemCancelQuery, S>(conn, udfs, &mut failed)?;
register_duckdb_udf::<typeof_func::SystemTypeofFunc, S>(conn, udfs, &mut failed)?;
Ok(failed)
}

/// Registers a UDF in `DuckDB` under its plain (non-`try_`) name.
///
/// The name is whatever `T::default().name()` reports; the lookup in `udfs`
/// and the registration itself are delegated to `register_duckdb_udf_internal`.
///
/// # Errors
///
/// Propagates any error returned by `register_duckdb_udf_internal`.
pub fn register_duckdb_udf<T, S>(
    conn: &Connection,
    udfs: &HashMap<String, Arc<ScalarUDF>, S>,
    failed: &mut Vec<String>,
) -> CoreResult<()>
where
    T: ScalarUDFImpl + Default + 'static,
    S: BuildHasher,
{
    // A default instance is needed only to read the function's name.
    let prototype = T::default();
    register_duckdb_udf_internal::<T, S>(conn, udfs, prototype.name(), failed)
}

/// Registers the `try_` variant of a UDF in `DuckDB`.
///
/// The registered name is `try_` followed by `T::default().name()`; the lookup
/// in `udfs` and the registration itself are delegated to
/// `register_duckdb_udf_internal`.
///
/// # Errors
///
/// Propagates any error returned by `register_duckdb_udf_internal`.
pub fn register_duckdb_udf_try<T, S>(
    conn: &Connection,
    udfs: &HashMap<String, Arc<ScalarUDF>, S>,
    failed: &mut Vec<String>,
) -> CoreResult<()>
where
    T: ScalarUDFImpl + Default + 'static,
    S: BuildHasher,
{
    let mut try_name = String::from("try_");
    try_name.push_str(T::default().name());
    register_duckdb_udf_internal::<T, S>(conn, udfs, &try_name, failed)
}

pub fn register_all_udfs(conn: &Connection) -> CoreResult<()> {
conn.register_scalar_function::<DfUdfWrapper<LengthFunc>>("length_test")
.context(ex_error::DuckdbSnafu)?;
/// Common registration path used for both plain and `try_` function names.
///
/// Looks `name` up in `udfs`, wraps the found UDF in a `UdfState`, and
/// registers `DfUdfWrapper<T>` with that state on `conn`. If `DuckDB` rejects
/// the registration, `name` is appended to `failed` and the function still
/// returns `Ok(())`.
///
/// # Errors
///
/// Returns an error when `udfs` has no entry for `name`.
fn register_duckdb_udf_internal<T, S>(
    conn: &Connection,
    udfs: &HashMap<String, Arc<ScalarUDF>, S>,
    name: &str,
    failed: &mut Vec<String>,
) -> CoreResult<()>
where
    T: ScalarUDFImpl + Default + 'static,
    S: BuildHasher,
{
    let udf = udfs
        .get(name)
        .map(Arc::clone)
        .ok_or_else(|| internal_datafusion_err!("Unable to find expected '{name}' function"))
        .context(ex_error::DataFusionSnafu)?;
    let state = UdfState::new(udf);

    // A DuckDB-side failure is recorded for the caller instead of aborting.
    let outcome = conn.register_scalar_function_with_state::<DfUdfWrapper<T>>(name, &state);
    if outcome.is_err() {
        failed.push(name.to_owned());
    }
    Ok(())
}

fn expand_signature<T: ScalarUDFImpl>(
func: &T,
sig: &TypeSignature,
) -> Vec<ArrowFunctionSignature> {
// DataFusion already knows all valid argument type combinations for this signature
let example_sigs = match sig {
TypeSignature::Any(arg_count) => {
let types = vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
types
.into_iter()
.map(|dt| vec![dt; *arg_count])
.collect::<Vec<_>>()
}
_ => sig.get_example_types(),
};

if example_sigs.is_empty() {
// Fallback when no examples are available (e.g., for generic or nullary signatures)
let ret = func
.return_type(&[DataType::Utf8])
.unwrap_or(DataType::Utf8);
return vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)];
}

// Build a DuckDB signature for each valid argument combination
example_sigs
.into_iter()
.map(|types| {
let ret = func.return_type(&types).unwrap_or(DataType::Utf8);
ArrowFunctionSignature::exact(types, ret)
})
.collect()
}
9 changes: 8 additions & 1 deletion crates/core-executor/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,14 @@ impl UserQuery {
let sql = self.query.clone();

let conn = Connection::open_in_memory().context(ex_error::DuckdbSnafu)?;
register_all_udfs(&conn)?;
let failed = register_all_udfs(&conn, self.session.ctx.state().scalar_functions())?;
if !failed.is_empty() {
tracing::warn!(
"Some UDFs were not registered/overloaded in DuckDB: {:?}",
failed
);
}

apply_connection_setup_queries(&conn, &setup_queries)?;

if self.session.config.use_duck_db_explain
Expand Down
25 changes: 25 additions & 0 deletions crates/embucket-functions/src/datetime/date_part_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use datafusion_expr::registry::FunctionRegistry;
use datafusion_expr::{ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility};
use snafu::OptionExt;
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
Expand All @@ -37,6 +38,30 @@ pub enum Interval {
Second,
}

/// Formats an `Interval` as a lowercase name with no separators,
/// e.g. `Interval::DayOfWeekIso` becomes `dayofweekiso`.
impl fmt::Display for Interval {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Year => "year",
            Self::YearOfWeek => "yearofweek",
            Self::YearOfWeekIso => "yearofweekiso",
            Self::Day => "day",
            Self::DayOfMonth => "dayofmonth",
            Self::DayOfWeek => "dayofweek",
            Self::DayOfWeekIso => "dayofweekiso",
            Self::DayOfYear => "dayofyear",
            Self::Week => "week",
            Self::WeekOfYear => "weekofyear",
            Self::WeekIso => "weekiso",
            Self::Month => "month",
            Self::Quarter => "quarter",
            Self::Hour => "hour",
            Self::Minute => "minute",
            Self::Second => "second",
        })
    }
}

/// `YEAR*` / `DAY*` / `WEEK*` / `MONTH` / `QUARTER` / `HOUR` / `MINUTE` / `SECOND` SQL function
///
/// Extracts a specific part of a date or timestamp.
Expand Down
8 changes: 4 additions & 4 deletions crates/embucket-functions/src/regexp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub mod errors;
pub mod regexp_instr;
mod regexp_like;
mod regexp_replace;
mod regexp_substr;
mod regexp_substr_all;
pub mod regexp_like;
pub mod regexp_replace;
pub mod regexp_substr;
pub mod regexp_substr_all;

use crate::regexp::regexp_instr::RegexpInstrFunc;
use crate::regexp::regexp_like::RegexpLikeFunc;
Expand Down
1 change: 1 addition & 0 deletions crates/embucket-functions/src/regexp/regexp_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl Default for RegexpLikeFunc {
}

impl RegexpLikeFunc {
#[must_use]
pub fn new() -> Self {
Self {
signature: Signature::one_of(
Expand Down
1 change: 1 addition & 0 deletions crates/embucket-functions/src/regexp/regexp_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl Default for RegexpReplaceFunc {
}

impl RegexpReplaceFunc {
#[must_use]
pub fn new() -> Self {
Self {
signature: Signature::one_of(
Expand Down
1 change: 1 addition & 0 deletions crates/embucket-functions/src/regexp/regexp_substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl Default for RegexpSubstrFunc {
}

impl RegexpSubstrFunc {
#[must_use]
pub fn new() -> Self {
Self {
signature: Signature::one_of(
Expand Down
1 change: 1 addition & 0 deletions crates/embucket-functions/src/regexp/regexp_substr_all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl Default for RegexpSubstrAllFunc {
}

impl RegexpSubstrAllFunc {
#[must_use]
pub fn new() -> Self {
Self {
signature: Signature::one_of(
Expand Down
Loading
Loading