diff --git a/crates/sqlbuiltins/src/functions/mod.rs b/crates/sqlbuiltins/src/functions/mod.rs
index 60c2dbb11..4de99606c 100644
--- a/crates/sqlbuiltins/src/functions/mod.rs
+++ b/crates/sqlbuiltins/src/functions/mod.rs
@@ -11,7 +11,7 @@
 use once_cell::sync::Lazy;
 use protogen::metastore::types::catalog::{EntryMeta, EntryType, FunctionEntry, FunctionType};
 use scalars::df_scalars::ArrowCastFunction;
-use scalars::hashing::{FnvHash, SipHash};
+use scalars::hashing::{FnvHash, PartitionResults, SipHash};
 use scalars::kdl::{KDLMatches, KDLSelect};
 use scalars::postgres::*;
 use scalars::{ConnectionId, Version};
@@ -192,9 +192,10 @@ impl FunctionRegistry {
             // KDL functions
             Arc::new(KDLMatches),
             Arc::new(KDLSelect),
-            // Hashing/Sharding
+            // Hashing/Partitioning
             Arc::new(SipHash),
             Arc::new(FnvHash),
+            Arc::new(PartitionResults),
         ];
         let udfs = udfs
             .into_iter()
diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs
index 68e31d9b8..19123b713 100644
--- a/crates/sqlbuiltins/src/functions/scalars/hashing.rs
+++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs
@@ -63,6 +63,7 @@ impl ConstBuiltinFunction for FnvHash {
         ))
     }
 }
+
 impl BuiltinScalarUDF for FnvHash {
     fn as_expr(&self, args: Vec<Expr>) -> Expr {
         let udf = ScalarUDF {
@@ -86,3 +87,76 @@ impl BuiltinScalarUDF for FnvHash {
         ))
     }
 }
+
+/// Scalar function that tests whether a value hashes into a given partition.
+///
+/// Called as `partition_results(<value>, <num_partitions>, <partition_id>)`; the
+/// result is true when the FNV hash of `<value>` modulo `<num_partitions>` equals
+/// `<partition_id>`.
+pub struct PartitionResults;
+
+impl ConstBuiltinFunction for PartitionResults {
+    const NAME: &'static str = "partition_results";
+    const DESCRIPTION: &'static str =
+        "Returns true if the value is in the partition ID given the number of partitions.";
+    const EXAMPLE: &'static str = "partition_results(<value>, <num_partitions>, <partition_id>)";
+    const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;
+
+    fn signature(&self) -> Option<Signature> {
+        Some(Signature::new(
+            // args: <value>, <num_partitions>, <partition_id>
+            TypeSignature::Any(3),
+            Volatility::Immutable,
+        ))
+    }
+}
+
+impl BuiltinScalarUDF for PartitionResults {
+    /// Wraps `args` in a call to the `partition_results` scalar UDF.
+    fn as_expr(&self, args: Vec<Expr>) -> Expr {
+        let udf = ScalarUDF {
+            name: Self::NAME.to_string(),
+            signature: ConstBuiltinFunction::signature(self).unwrap(),
+            // The closure below always yields `ScalarValue::Boolean`, so the
+            // declared return type must be `Boolean` to match.
+            return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))),
+            fun: Arc::new(move |input| {
+                if input.len() != 3 {
+                    return Err(DataFusionError::Execution(
+                        "must specify exactly three arguments".to_string(),
+                    ));
+                }
+
+                let num_partitions = get_nth_u64_fn_arg(input, 1)?;
+                let partition_id = get_nth_u64_fn_arg(input, 2)?;
+
+                // Also rejects `num_partitions == 0`, which keeps the modulo
+                // below from dividing by zero.
+                if partition_id >= num_partitions {
+                    return Err(DataFusionError::Execution(format!(
+                        "id {} must be less than number of partitions {}",
+                        partition_id, num_partitions,
+                    )));
+                }
+
+                // hash at the end once the other arguments are
+                // validated because the hashing is potentially the
+                // expensive part
+                Ok(get_nth_scalar_value(input, 0, &|value| -> Result<
+                    ScalarValue,
+                    BuiltinError,
+                > {
+                    let mut hasher = FnvHasher::default();
+                    value.hash(&mut hasher);
+                    Ok(ScalarValue::Boolean(Some(
+                        hasher.finish() % num_partitions == partition_id,
+                    )))
+                })?)
+            }),
+        };
+        Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new(
+            Arc::new(udf),
+            args,
+        ))
+    }
+}
diff --git a/testdata/sqllogictests/functions/hashing.slt b/testdata/sqllogictests/functions/hashing.slt
index d66ef0012..e5d217c98 100644
--- a/testdata/sqllogictests/functions/hashing.slt
+++ b/testdata/sqllogictests/functions/hashing.slt
@@ -106,3 +106,65 @@ query I
 select siphash('42');
 ----
 8771948186893062792
+
+########################################################################
+#
+# partition_results(<value>, <num_partitions>, <partition_id>)
+#
+########################################################################
+
+statement error
+select partition_results();
+
+statement error
+select partition_results('buddy', 100, 2, 3);
+
+statement error
+select partition_results('buddy', -100, -2);
+
+statement error
+select partition_results('buddy', 100, 200);
+
+statement error
+select partition_results('9001', '100', '10');
+
+statement error
+select partition_results(9001, 100, '10');
+
+statement error
+select partition_results(9001, '100', 10);
+
+statement ok
+select partition_results(100, 10, 0);
+
+statement ok
+select partition_results(100, 10.0, 1.0);
+
+statement error
+select partition_results(100, 10.5, 1.5);
+
+statement error
+select partition_results(16, 4, 4);
+
+statement error
+select partition_results(16, 0, 0);
+
+query B
+select partition_results(16, 4, 0);
+----
+t
+
+query B
+select partition_results(16, 4, 1);
+----
+f
+
+query B
+select partition_results(16, 4, 2);
+----
+f
+
+query B
+select partition_results(16, 4, 3);
+----
+f