Skip to content

Commit

Permalink
feat: partition results helper (#2299)
Browse files Browse the repository at this point in the history
Adds partitioning (sharding) of result sets using the hashing method. 

Closes #2220
  • Loading branch information
tychoish authored Dec 28, 2023
1 parent 026bacd commit ee4b440
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 2 deletions.
5 changes: 3 additions & 2 deletions crates/sqlbuiltins/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use once_cell::sync::Lazy;

use protogen::metastore::types::catalog::{EntryMeta, EntryType, FunctionEntry, FunctionType};
use scalars::df_scalars::ArrowCastFunction;
use scalars::hashing::{FnvHash, SipHash};
use scalars::hashing::{FnvHash, PartitionResults, SipHash};
use scalars::kdl::{KDLMatches, KDLSelect};
use scalars::postgres::*;
use scalars::{ConnectionId, Version};
Expand Down Expand Up @@ -192,9 +192,10 @@ impl FunctionRegistry {
// KDL functions
Arc::new(KDLMatches),
Arc::new(KDLSelect),
// Hashing/Sharding
// Hashing/Partitioning
Arc::new(SipHash),
Arc::new(FnvHash),
Arc::new(PartitionResults),
];
let udfs = udfs
.into_iter()
Expand Down
67 changes: 67 additions & 0 deletions crates/sqlbuiltins/src/functions/scalars/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl ConstBuiltinFunction for FnvHash {
))
}
}

impl BuiltinScalarUDF for FnvHash {
fn as_expr(&self, args: Vec<Expr>) -> Expr {
let udf = ScalarUDF {
Expand All @@ -86,3 +87,69 @@ impl BuiltinScalarUDF for FnvHash {
))
}
}

/// Scalar builtin `partition_results(<value>, <num_partitions>, <partition_id>)`.
///
/// Hashes `<value>` with FNV and reports whether the hash, taken modulo
/// `<num_partitions>`, equals `<partition_id>`.
pub struct PartitionResults;

/// Catalog metadata (name, description, usage example, signature) for the
/// `partition_results` scalar function.
impl ConstBuiltinFunction for PartitionResults {
    const NAME: &'static str = "partition_results";
    const DESCRIPTION: &'static str =
        "Returns true if the value is in the partition ID given the number of partitions.";
    const EXAMPLE: &'static str = "partition_results(<value>, <num_partitions>, <partition_id>)";
    const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;

    /// Declares a three-argument, immutable function.
    ///
    /// The positional arguments are `<FIELD>`, `<num_partitions>` and
    /// `<partition_id>`; the signature itself places no constraint on their
    /// types.
    fn signature(&self) -> Option<Signature> {
        let accepted_args = TypeSignature::Any(3);
        let signature = Signature::new(accepted_args, Volatility::Immutable);
        Some(signature)
    }
}

impl BuiltinScalarUDF for PartitionResults {
    /// Builds the logical expression that invokes `partition_results` on `args`.
    ///
    /// At execution time the function:
    /// 1. requires exactly three arguments,
    /// 2. reads arguments 1 and 2 as unsigned integers (`num_partitions` and
    ///    `partition_id`),
    /// 3. rejects the call unless `partition_id < num_partitions`,
    /// 4. FNV-hashes argument 0 and returns a boolean telling whether
    ///    `hash % num_partitions == partition_id`.
    ///
    /// # Errors
    ///
    /// The generated function returns `DataFusionError::Execution` when the
    /// argument count is not three or when `partition_id >= num_partitions`,
    /// and propagates any error from reading the arguments.
    fn as_expr(&self, args: Vec<Expr>) -> Expr {
        let udf = ScalarUDF {
            name: Self::NAME.to_string(),
            signature: ConstBuiltinFunction::signature(self).unwrap(),
            // The closure below always yields `ScalarValue::Boolean`, so the
            // declared return type must be Boolean as well.
            return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))),
            fun: Arc::new(move |input| {
                if input.len() != 3 {
                    return Err(DataFusionError::Execution(
                        "must specify exactly three arguments".to_string(),
                    ));
                }

                let num_partitions = get_nth_u64_fn_arg(input, 1)?;
                let partition_id = get_nth_u64_fn_arg(input, 2)?;

                // Both values are unsigned, so this check also rejects
                // `num_partitions == 0` and keeps the modulo below from
                // dividing by zero.
                if partition_id >= num_partitions {
                    return Err(DataFusionError::Execution(format!(
                        "id {} must be less than number of partitions {}",
                        partition_id, num_partitions,
                    )));
                }

                // hash at the end once the other arguments are
                // validated because the hashing is potentially the
                // expensive part
                Ok(get_nth_scalar_value(input, 0, &|value| -> Result<
                    ScalarValue,
                    BuiltinError,
                > {
                    let mut hasher = FnvHasher::default();
                    value.hash(&mut hasher);
                    Ok(ScalarValue::Boolean(Some(
                        hasher.finish() % num_partitions == partition_id,
                    )))
                })?)
            }),
        };
        Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new(
            Arc::new(udf),
            args,
        ))
    }
}
59 changes: 59 additions & 0 deletions testdata/sqllogictests/functions/hashing.slt
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,62 @@ query I
select siphash('42');
----
8771948186893062792

########################################################################
#
# partition_results(<value>, <num_partitions>, <partition_id>)
#
########################################################################

statement error
select partition_results();

statement error
select partition_results('buddy', 100, 2, 3);

statement error
select partition_results('buddy', -100, -2);

statement error
select partition_results('buddy', 100, 200);

statement error
select partition_results('9001', '100', '10');

statement error
select partition_results(9001, 100, '10');

statement error
select partition_results(9001, '100', 10);

statement ok
select partition_results(100, 10, 0);

statement ok
select partition_results(100, 10.0, 1.0);

statement error
select partition_results(100, 10.5, 1.5);

statement error
select partition_results(16, 4, 4);

query B
select partition_results(16, 4, 0);
----
t

query B
select partition_results(16, 4, 1);
----
f

query B
select partition_results(16, 4, 2);
----
f

query B
select partition_results(16, 4, 3);
----
f

0 comments on commit ee4b440

Please sign in to comment.