From 727563b2c25e26ce70aff98267cbe7d05b9c2cd7 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Fri, 22 Dec 2023 20:10:25 -0500 Subject: [PATCH 01/14] feat: add non-cryptographic hash function --- Cargo.lock | 10 +- crates/sqlbuiltins/Cargo.toml | 2 + crates/sqlbuiltins/src/functions/mod.rs | 26 +++-- .../src/functions/scalars/hashing.rs | 100 ++++++++++++++++++ .../sqlbuiltins/src/functions/scalars/mod.rs | 26 ++--- testdata/sqllogictests/functions/hashing.slt | 98 +++++++++++++++++ 6 files changed, 238 insertions(+), 24 deletions(-) create mode 100644 crates/sqlbuiltins/src/functions/scalars/hashing.rs create mode 100644 testdata/sqllogictests/functions/hashing.slt diff --git a/Cargo.lock b/Cargo.lock index eacc4bb4a..a7f1ab90b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5055,7 +5055,7 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" dependencies = [ - "siphasher", + "siphasher 0.3.11", ] [[package]] @@ -6728,6 +6728,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54ac45299ccbd390721be55b412d41931911f654fa99e2cb8bfb57184b2061fe" + [[package]] name = "skeptic" version = "0.13.7" @@ -6858,6 +6864,7 @@ dependencies = [ "datafusion_ext", "datasources", "decimal", + "fnv", "futures", "ioutil", "kdl", @@ -6870,6 +6877,7 @@ dependencies = [ "regex", "serde", "serde_json", + "siphasher 1.0.0", "strum", "telemetry", "thiserror", diff --git a/crates/sqlbuiltins/Cargo.toml b/crates/sqlbuiltins/Cargo.toml index 4694cdb4d..e3783815f 100644 --- a/crates/sqlbuiltins/Cargo.toml +++ b/crates/sqlbuiltins/Cargo.toml @@ -32,4 +32,6 @@ num-traits = "0.2.17" url.workspace = true strum = "0.25.0" kdl = "5.0.0-alpha.1" +siphasher = "1.0.0" +fnv = "1.0.7" diff --git a/crates/sqlbuiltins/src/functions/mod.rs b/crates/sqlbuiltins/src/functions/mod.rs index ac5f71d81..ce0ca1ffd 100644 --- a/crates/sqlbuiltins/src/functions/mod.rs +++ b/crates/sqlbuiltins/src/functions/mod.rs @@ -3,20 +3,23 @@ mod aggregates; mod scalars; mod table; -use self::scalars::df_scalars::ArrowCastFunction; -use self::scalars::kdl::{KDLMatches, KDLSelect}; -use self::scalars::{postgres::*, ConnectionId, Version}; -use self::table::{BuiltinTableFuncs, TableFunc}; +use std::collections::HashMap; +use std::sync::Arc; use datafusion::logical_expr::{AggregateFunction, BuiltinScalarFunction, Expr, Signature}; use once_cell::sync::Lazy; + +use scalars::df_scalars::ArrowCastFunction; +use scalars::hashing::{FnvHash, SipHash}; +use scalars::kdl::{KDLMatches, KDLSelect}; +use scalars::postgres::*; +use scalars::{ConnectionId, Version}; +use table::{BuiltinTableFuncs, TableFunc}; + use protogen::metastore::types::catalog::{ EntryMeta, EntryType, FunctionEntry, FunctionType, RuntimePreference, }; -use std::collections::HashMap; -use std::sync::Arc; - /// Builtin table returning functions available for all sessions. static BUILTIN_TABLE_FUNCS: Lazy = Lazy::new(BuiltinTableFuncs::new); pub static ARROW_CAST_FUNC: Lazy = Lazy::new(|| ArrowCastFunction {}); @@ -187,12 +190,15 @@ impl FunctionRegistry { Arc::new(PgTableIsVisible), Arc::new(PgEncodingToChar), Arc::new(PgArrayToString), + // System functions + Arc::new(ConnectionId), + Arc::new(Version), // KDL functions Arc::new(KDLMatches), Arc::new(KDLSelect), - // Other functions - Arc::new(ConnectionId), - Arc::new(Version), + // Hashing/Sharding + Arc::new(SipHash), + Arc::new(FnvHash), ]; let udfs = udfs .into_iter() diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs new file mode 100644 index 000000000..00a1f75ef --- /dev/null +++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs @@ -0,0 +1,100 @@ +use std::hash::{Hash, Hasher}; + +use fnv::FnvHasher; +use siphasher::sip::SipHasher24; + +use super::*; + +pub struct SipHash; + +impl ConstBuiltinFunction for SipHash { + const NAME: &'static str = "siphash"; + const DESCRIPTION: &'static str = + "Calculates a 64bit non-cryptographic hash (SipHash24) of the value."; + const EXAMPLE: &'static str = "siphash()"; + const FUNCTION_TYPE: FunctionType = FunctionType::Scalar; + + fn signature(&self) -> Option { + Some(Signature::new( + // args: + TypeSignature::Any(1), + Volatility::Immutable, + )) + } +} +impl BuiltinScalarUDF for SipHash { + fn as_expr(&self, args: Vec) -> Expr { + let udf = ScalarUDF { + name: Self::NAME.to_string(), + signature: ConstBuiltinFunction::signature(self).unwrap(), + return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), + fun: Arc::new(move |input| match get_nth_scalar_value(input, 0) { + Some(value) => { + let mut hasher = SipHasher24::new(); + + value.hash(&mut hasher); + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + hasher.finish(), + )))) + } + None => { + return Err(datafusion::error::DataFusionError::Execution( + "must have exactly one value to hash".to_string(), + )) + } + }), + }; + Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( + Arc::new(udf), + args, + )) + } +} + +pub struct FnvHash; + +impl ConstBuiltinFunction for FnvHash { + const NAME: &'static str = "fnv"; + const DESCRIPTION: &'static str = + "Calculates a 64bit non-cryptographic hash (fnv1a) of the value."; + const EXAMPLE: &'static str = "fnv()"; + const FUNCTION_TYPE: FunctionType = FunctionType::Scalar; + + fn signature(&self) -> Option { + Some(Signature::new( + // args: + TypeSignature::Any(1), + Volatility::Immutable, + )) + } +} +impl BuiltinScalarUDF for FnvHash { + fn as_expr(&self, args: Vec) -> Expr { + let udf = ScalarUDF { + name: Self::NAME.to_string(), + signature: ConstBuiltinFunction::signature(self).unwrap(), + return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), + fun: Arc::new(move |input| match get_nth_scalar_value(input, 0) { + Some(value) => { + let mut hasher = FnvHasher::default(); + + value.hash(&mut hasher); + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + hasher.finish(), + )))) + } + None => { + return Err(datafusion::error::DataFusionError::Execution( + "must have exactly one value to hash".to_string(), + )) + } + }), + }; + Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( + Arc::new(udf), + args, + )) + } +} diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 2c24c1b29..d5f13611f 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -1,24 +1,22 @@ pub mod df_scalars; +pub mod hashing; pub mod kdl; pub mod postgres; -use crate::{ - document, - functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction}, -}; -use datafusion::logical_expr::BuiltinScalarFunction; - -use protogen::metastore::types::catalog::FunctionType; use std::sync::Arc; -use datafusion::{ - arrow::datatypes::{DataType, Field}, - logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}, - physical_plan::ColumnarValue, - scalar::ScalarValue, -}; +use crate::document; +use crate::functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction}; +use datafusion::logical_expr::BuiltinScalarFunction; +use protogen::metastore::types::catalog::FunctionType; + +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}; +use datafusion::physical_plan::ColumnarValue; +use datafusion::scalar::ScalarValue; pub struct ConnectionId; + impl ConstBuiltinFunction for ConnectionId { const NAME: &'static str = "connection_id"; const DESCRIPTION: &'static str = "Returns the connection id of the current session"; @@ -34,7 +32,9 @@ impl BuiltinScalarUDF for ConnectionId { session_var("connection_id") } } + pub struct Version; + impl ConstBuiltinFunction for Version { const NAME: &'static str = "version"; const DESCRIPTION: &'static str = "Returns the version of the database"; diff --git a/testdata/sqllogictests/functions/hashing.slt b/testdata/sqllogictests/functions/hashing.slt new file mode 100644 index 000000000..58ec3e0e5 --- /dev/null +++ b/testdata/sqllogictests/functions/hashing.slt @@ -0,0 +1,98 @@ +statement error +select siphash(1, 2, 3); + +statement error +select siphash(1, 2); + +statement ok +select siphash(1); + +statement ok +select siphash('000'); + +statement ok +select siphash(9001); + +statement ok +select siphash(true); + + +statement error +select fnv(1, 2, 3); + +statement error +select fnv(1, 2); + +statement ok +select fnv(1); + +statement ok +select fnv('000'); + +statement ok +select fnv(9001); + +statement ok +select fnv(true); + +query I +select siphash(); +---- +13715208377448023093 + +query I +select siphash(42); +---- +8315904219845249920 + +query I +select siphash(3000); +---- +14490819164275330428 + +query I +select siphash('42'); +---- +8771948186893062792 + +query I +select fnv(); +---- +12478008331234465636 + +query I +select fnv(42); +---- +10346157209210711374 + +query I +select fnv(3000); +---- +4500112066730064389 + +query I +select fnv('42'); +---- +16857446072837519227 + +# rerun some earlier cases to ensure we're not accidentally stateful + +query I +select fnv(); +---- +12478008331234465636 + +query I +select fnv('42'); +---- +16857446072837519227 + +query I +select siphash(); +---- +13715208377448023093 + +query I +select siphash('42'); +---- +8771948186893062792 From 5c498d1adb6b47e410a0bd2c15a34da20811820a Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sat, 23 Dec 2023 09:08:10 -0500 Subject: [PATCH 02/14] add multi-value --- testdata/sqllogictests/functions/hashing.slt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/testdata/sqllogictests/functions/hashing.slt b/testdata/sqllogictests/functions/hashing.slt index 58ec3e0e5..d66ef0012 100644 --- a/testdata/sqllogictests/functions/hashing.slt +++ b/testdata/sqllogictests/functions/hashing.slt @@ -55,6 +55,11 @@ select siphash('42'); ---- 8771948186893062792 +query I +select siphash([0, 100, 100, 300, 500, 800]); +---- +3773389192103504674 + query I select fnv(); ---- @@ -75,6 +80,11 @@ select fnv('42'); ---- 16857446072837519227 +query I +select fnv([0, 100, 100, 300, 500, 800]); +---- +16366770854503053831 + # rerun some earlier cases to ensure we're not accidentally stateful query I From b90a281c2fc078d1aa976cccb7e667aadaf9f5b1 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sat, 23 Dec 2023 09:17:51 -0500 Subject: [PATCH 03/14] fix lint --- .../sqlbuiltins/src/functions/scalars/hashing.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs index 00a1f75ef..f765328cf 100644 --- a/crates/sqlbuiltins/src/functions/scalars/hashing.rs +++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs @@ -38,11 +38,9 @@ impl BuiltinScalarUDF for SipHash { hasher.finish(), )))) } - None => { - return Err(datafusion::error::DataFusionError::Execution( - "must have exactly one value to hash".to_string(), - )) - } + None => Err(datafusion::error::DataFusionError::Execution( + "must have exactly one value to hash".to_string(), + )), }), }; Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( @@ -85,11 +83,9 @@ impl BuiltinScalarUDF for FnvHash { hasher.finish(), )))) } - None => { - return Err(datafusion::error::DataFusionError::Execution( - "must have exactly one value to hash".to_string(), - )) - } + None => Err(datafusion::error::DataFusionError::Execution( + "must have exactly one value to hash".to_string(), + )), }), }; Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( From 133f60ceddaad72da326d933b66f2c7392b7dccd Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sat, 23 Dec 2023 23:36:47 -0500 Subject: [PATCH 04/14] fix sig --- crates/sqlbuiltins/src/functions/scalars/hashing.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs index f765328cf..1d6bfa59c 100644 --- a/crates/sqlbuiltins/src/functions/scalars/hashing.rs +++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs @@ -27,7 +27,7 @@ impl BuiltinScalarUDF for SipHash { let udf = ScalarUDF { name: Self::NAME.to_string(), signature: ConstBuiltinFunction::signature(self).unwrap(), - return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), + return_type: Arc::new(|_| Ok(Arc::new(DataType::UInt64))), fun: Arc::new(move |input| match get_nth_scalar_value(input, 0) { Some(value) => { let mut hasher = SipHasher24::new(); @@ -72,7 +72,7 @@ impl BuiltinScalarUDF for FnvHash { let udf = ScalarUDF { name: Self::NAME.to_string(), signature: ConstBuiltinFunction::signature(self).unwrap(), - return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), + return_type: Arc::new(|_| Ok(Arc::new(DataType::UInt64))), fun: Arc::new(move |input| match get_nth_scalar_value(input, 0) { Some(value) => { let mut hasher = FnvHasher::default(); From bd3da32198de0dd5163797ff1a2d9d9bce6580ae Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 10:49:52 -0500 Subject: [PATCH 05/14] fix: scalar udf function parsing --- crates/sqlbuiltins/src/errors.rs | 8 +++ .../sqlbuiltins/src/functions/scalars/mod.rs | 52 +++++++++++++------ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index 79a300d16..2384ff5a4 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -1,7 +1,15 @@ +use datafusion::arrow::error::ArrowError; + #[derive(Debug, thiserror::Error)] pub enum BuiltinError { #[error(transparent)] DatafusionExtError(#[from] datafusion_ext::errors::ExtensionError), + + #[error(transparent)] + DatafusionError(#[from] datafusion::error::DataFusionError), + + #[error(transparent)] + ArrowError(#[from] ArrowError), } pub type Result = std::result::Result; diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 2c24c1b29..4f4965f23 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -1,24 +1,23 @@ pub mod df_scalars; pub mod kdl; pub mod postgres; -use crate::{ - document, - functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction}, -}; -use datafusion::logical_expr::BuiltinScalarFunction; - -use protogen::metastore::types::catalog::FunctionType; use std::sync::Arc; -use datafusion::{ - arrow::datatypes::{DataType, Field}, - logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}, - physical_plan::ColumnarValue, - scalar::ScalarValue, -}; +use datafusion::arrow::array::{make_array, Array, ArrayDataBuilder}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::logical_expr::BuiltinScalarFunction; +use datafusion::logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}; +use datafusion::physical_plan::ColumnarValue; +use datafusion::scalar::ScalarValue; + +use crate::document; +use crate::errors::BuiltinError; +use crate::functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction}; +use protogen::metastore::types::catalog::FunctionType; pub struct ConnectionId; + impl ConstBuiltinFunction for ConnectionId { const NAME: &'static str = "connection_id"; const DESCRIPTION: &'static str = "Returns the connection id of the current session"; @@ -34,7 +33,9 @@ impl BuiltinScalarUDF for ConnectionId { session_var("connection_id") } } + pub struct Version; + impl ConstBuiltinFunction for Version { const NAME: &'static str = "version"; const DESCRIPTION: &'static str = "Returns the version of the database"; @@ -51,13 +52,30 @@ impl BuiltinScalarUDF for Version { } } -fn get_nth_scalar_value(input: &[ColumnarValue], n: usize) -> Option { +fn get_nth_scalar_value( + input: &[ColumnarValue], + n: usize, + op: &dyn Fn(Option) -> Result, + output_type: DataType, +) -> Result { match input.get(n) { Some(input) => match input { - ColumnarValue::Scalar(scalar) => Some(scalar.clone()), - ColumnarValue::Array(arr) => ScalarValue::try_from_array(arr, 0).ok(), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(Some(scalar.clone()))?)), + ColumnarValue::Array(arr) => { + let mut builder = ArrayDataBuilder::new(output_type); + + for idx in 0..arr.len() { + builder.add_child_data( + op(Some(ScalarValue::try_from_array(arr, idx)?))? + .to_array() + .into_data(), + ); + } + + Ok(ColumnarValue::Array(make_array(builder.build()?))) + } }, - None => None, + None => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)))), } } From 94895c0e6cdb9bb815c86280ccfee9cd045a7a54 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 10:58:58 -0500 Subject: [PATCH 06/14] interger upcast --- .../sqlbuiltins/src/functions/scalars/mod.rs | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 4f4965f23..c8ca7e396 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -6,10 +6,12 @@ use std::sync::Arc; use datafusion::arrow::array::{make_array, Array, ArrayDataBuilder}; use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::error::DataFusionError; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::logical_expr::{Expr, ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue; use datafusion::scalar::ScalarValue; +use num_traits::ToPrimitive; use crate::document; use crate::errors::BuiltinError; @@ -79,6 +81,83 @@ fn get_nth_scalar_value( } } +fn safe_up_cast_integer_scalar( + dt: DataType, + idx: usize, + value: i64, +) -> Result { + if value < 0 { + Err(datafusion::error::DataFusionError::Execution( + format!( + "expected {} value at {} to be greater than zero or unsigned", + dt, idx, + ) + .to_string(), + )) + } else { + Ok(value as u64) + } +} + +fn get_nth_scalar_as_u64(input: &[ColumnarValue], n: usize) -> Result { + match input.get(n) { + Some(input) => match input { + ColumnarValue::Scalar(scalar) => match scalar.clone() { + ScalarValue::Int8(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int16(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int32(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int64(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value) + } + ScalarValue::UInt8(Some(value)) => Ok(value as u64), + ScalarValue::UInt16(Some(value)) => Ok(value as u64), + ScalarValue::UInt32(Some(value)) => Ok(value as u64), + ScalarValue::Float64(Some(value)) => { + if value.trunc() != value { + Err(datafusion::error::DataFusionError::Execution( + format!("float value {} at index {}, expected integer", value, n) + .to_string(), + )) + } else { + Ok(value.to_i64().unwrap() as u64) + } + } + ScalarValue::Float32(Some(value)) => { + if value.trunc() != value { + Err(datafusion::error::DataFusionError::Execution( + format!("float value {} at index {}, expected integer", value, n) + .to_string(), + )) + } else { + Ok(value.to_i64().unwrap() as u64) + } + } + ScalarValue::UInt64(Some(value)) => Ok(value), + _ => Err(datafusion::error::DataFusionError::Execution( + format!( + "value in index {} was {}, expected integer", + n, + scalar.data_type() + ) + .to_string(), + )), + }, + ColumnarValue::Array(_) => Err(datafusion::error::DataFusionError::Execution( + format!("invalid array value in index {}, expected integer", n).to_string(), + )), + }, + None => Err(datafusion::error::DataFusionError::Execution( + format!("expected integer value in index {}", n).to_string(), + )), + } +} + fn session_var(s: &str) -> Expr { Expr::ScalarVariable(DataType::Utf8, vec![s.to_string()]) } From 9afe476dd43cddb10a6648ec18514814571003dc Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 12:30:19 -0500 Subject: [PATCH 07/14] mod updates --- Cargo.lock | 35 +++++- crates/sqlbuiltins/Cargo.toml | 1 + crates/sqlbuiltins/src/errors.rs | 44 ++++++-- .../sqlbuiltins/src/functions/scalars/kdl.rs | 67 +++--------- .../sqlbuiltins/src/functions/scalars/mod.rs | 102 ++++++++++++------ 5 files changed, 162 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 80ca901c8..7078d71c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4012,6 +4012,15 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "lru" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" +dependencies = [ + "hashbrown 0.12.3", +] + [[package]] name = "lru" version = "0.12.0" @@ -4145,6 +4154,29 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5df4051db13d0816cf23196d3baa216385ae099339f5d0645a8d9ff2305e82b8" +dependencies = [ + "lazy_static", + "lru 0.7.8", + "memoize-inner", +] + +[[package]] +name = "memoize-inner" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bfde264c318ec8c2de5c39e0ba3910fac8d1065e3b947b183ebd884b799719b" +dependencies = [ + "lazy_static", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "metastore" version = "0.7.1" @@ -4370,7 +4402,7 @@ dependencies = [ "futures-util", "keyed_priority_queue", "lazy_static", - "lru", + "lru 0.12.0", "mio", "mysql_common", "once_cell", @@ -6865,6 +6897,7 @@ dependencies = [ "ioutil", "kdl", "logutil", + "memoize", "num-traits", "object_store", "once_cell", diff --git a/crates/sqlbuiltins/Cargo.toml b/crates/sqlbuiltins/Cargo.toml index 4b254d2da..f9d61efe3 100644 --- a/crates/sqlbuiltins/Cargo.toml +++ b/crates/sqlbuiltins/Cargo.toml @@ -35,3 +35,4 @@ bson = "2.7.0" tokio-util = "0.7.10" bytes = "1.5.0" kdl = "5.0.0-alpha.1" +memoize = { version = "0.4.2", features = ["full"] } diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index 2384ff5a4..0a9fb5129 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -1,15 +1,47 @@ +use datafusion::arrow::datatypes::DataType; use datafusion::arrow::error::ArrowError; +use datafusion::error::DataFusionError; -#[derive(Debug, thiserror::Error)] +#[derive(Clone, Debug, thiserror::Error)] pub enum BuiltinError { - #[error(transparent)] - DatafusionExtError(#[from] datafusion_ext::errors::ExtensionError), + #[error("parse error: {0}")] + ParseError(String), - #[error(transparent)] - DatafusionError(#[from] datafusion::error::DataFusionError), + #[error("missing value at index {0}")] + MissingValueAtIndex(usize), + + #[error("invalid value at index {0}")] + InvalidValueAtIndex(usize, String), + + #[error("value at index {0} was {1}, expected {2}")] + IncorrectTypeAtIndex(usize, DataType, DataType), #[error(transparent)] - ArrowError(#[from] ArrowError), + KdlError(#[from] kdl::KdlError), + + #[error("DataFusionError: {0}")] + DataFusionError(String), + + #[error("ArrowError: {0}")] + ArrowError(String), } pub type Result = std::result::Result; + +impl From for DataFusionError { + fn from(e: BuiltinError) -> Self { + DataFusionError::Execution(e.to_string()) + } +} + +impl From for BuiltinError { + fn from(e: DataFusionError) -> Self { + BuiltinError::DataFusionError(e.to_string()) + } +} + +impl From for BuiltinError { + fn from(e: ArrowError) -> Self { + BuiltinError::ArrowError(e.to_string()) + } +} diff --git a/crates/sqlbuiltins/src/functions/scalars/kdl.rs b/crates/sqlbuiltins/src/functions/scalars/kdl.rs index 77447b31d..3c8112c89 100644 --- a/crates/sqlbuiltins/src/functions/scalars/kdl.rs +++ b/crates/sqlbuiltins/src/functions/scalars/kdl.rs @@ -1,7 +1,8 @@ -use super::*; use ::kdl::{KdlDocument, KdlNode, KdlQuery}; +use memoize::memoize; + +use super::*; -#[derive(Clone)] pub struct KDLSelect; impl ConstBuiltinFunction for KDLSelect { @@ -61,8 +62,8 @@ impl BuiltinScalarUDF for KDLSelect { } } -#[derive(Clone)] pub struct KDLMatches; + impl ConstBuiltinFunction for KDLMatches { const NAME: &'static str = "kdl_matches"; const DESCRIPTION: &'static str = @@ -83,19 +84,12 @@ impl ConstBuiltinFunction for KDLMatches { )) } } + impl BuiltinScalarUDF for KDLMatches { fn as_expr(&self, args: Vec) -> Expr { let udf = ScalarUDF { - name: "kdl_matches".to_string(), - signature: Signature::new( - TypeSignature::OneOf(vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), - ]), - Volatility::Immutable, - ), + name: Self::NAME.to_string(), + signature: ConstBuiltinFunction::signature(self).unwrap(), return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))), fun: Arc::new(move |input| { let (doc, filter) = kdl_parse_udf_args(input)?; @@ -120,43 +114,16 @@ fn kdl_parse_udf_args( // parse the filter first, because it's probably shorter and // erroring earlier would be preferable to parsing a large that we // don't need/want. - let filter: kdl::KdlQuery = match get_nth_scalar_value(args, 1) { - Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => { - val.parse().map_err(|err: ::kdl::KdlError| { - datafusion::common::DataFusionError::Execution(err.to_string()) - })? - } - Some(val) => { - return Err(datafusion::common::DataFusionError::Execution(format!( - "invalid type for KQL expression {}", - val.data_type(), - ))) - } - None => { - return Err(datafusion::common::DataFusionError::Execution( - "missing KQL query".to_string(), - )) - } - }; - - let doc: kdl::KdlDocument = match get_nth_scalar_value(args, 0) { - Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => { - val.parse().map_err(|err: ::kdl::KdlError| { - datafusion::common::DataFusionError::Execution(err.to_string()) - })? - } - Some(val) => { - return Err(datafusion::common::DataFusionError::Execution(format!( - "invalid type for KDL value {}", - val.data_type(), - ))) - } - None => { - return Err(datafusion::common::DataFusionError::Execution( - "invalid field for KDL".to_string(), - )) - } - }; + let filter = compile_kdl_query(get_nth_string_value(args, 1)?)?; + + let doc: kdl::KdlDocument = get_nth_string_value(args, 0)? + .parse() + .map_err(BuiltinError::KdlError)?; Ok((doc, filter)) } + +#[memoize(Capacity: 256)] +fn compile_kdl_query(query: String) -> Result { + query.parse().map_err(BuiltinError::KdlError) +} diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index c8ca7e396..9b20e4ed7 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -99,10 +99,22 @@ fn safe_up_cast_integer_scalar( } } -fn get_nth_scalar_as_u64(input: &[ColumnarValue], n: usize) -> Result { - match input.get(n) { - Some(input) => match input { - ColumnarValue::Scalar(scalar) => match scalar.clone() { +fn get_nth_scalar_as_u64( + input: &[ColumnarValue], + n: usize, + op: &dyn Fn(Option) -> Result, + output_type: DataType, +) -> Result { + get_nth_scalar_value( + input, + n, + &|scalar| -> Result { + let scalar = match scalar.clone() { + Some(v) => v.to_owned(), + None => return Err(BuiltinError::MissingValueAtIndex(n)), + }; + + let value = match scalar { ScalarValue::Int8(Some(value)) => { safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) } @@ -120,42 +132,72 @@ fn get_nth_scalar_as_u64(input: &[ColumnarValue], n: usize) -> Result Ok(value as u64), ScalarValue::Float64(Some(value)) => { if value.trunc() != value { - Err(datafusion::error::DataFusionError::Execution( - format!("float value {} at index {}, expected integer", value, n) - .to_string(), - )) - } else { - Ok(value.to_i64().unwrap() as u64) + return Err(BuiltinError::InvalidValueAtIndex( + n, + format!("expected whole value for float {}", value).to_string(), + )); } + Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::UInt64, + ))? as u64) } ScalarValue::Float32(Some(value)) => { if value.trunc() != value { - Err(datafusion::error::DataFusionError::Execution( - format!("float value {} at index {}, expected integer", value, n) - .to_string(), - )) - } else { - Ok(value.to_i64().unwrap() as u64) + return Err(BuiltinError::InvalidValueAtIndex( + n, + format!("expected whole value for float {}", value).to_string(), + )); } + Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::UInt64, + ))? as u64) } ScalarValue::UInt64(Some(value)) => Ok(value), - _ => Err(datafusion::error::DataFusionError::Execution( - format!( - "value in index {} was {}, expected integer", + _ => { + return Err(BuiltinError::IncorrectTypeAtIndex( n, - scalar.data_type() - ) - .to_string(), + scalar.data_type(), + DataType::UInt64, + )) + } + }?; + + op(Some(value)) + }, + output_type, + ) +} + +fn get_nth_string_value( + input: &[ColumnarValue], + n: usize, + op: &dyn Fn(Option) -> Result, + output_type: DataType, +) -> Result { + get_nth_scalar_value( + input, + n, + &|scalar| -> Result { + let scalar = match scalar.clone() { + Some(v) => v.to_owned(), + None => return Err(BuiltinError::MissingValueAtIndex(n)), + }; + + match scalar { + ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v) => op(v), + _ => Err(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::Utf8, )), - }, - ColumnarValue::Array(_) => Err(datafusion::error::DataFusionError::Execution( - format!("invalid array value in index {}, expected integer", n).to_string(), - )), + } }, - None => Err(datafusion::error::DataFusionError::Execution( - format!("expected integer value in index {}", n).to_string(), - )), - } + output_type, + ) } fn session_var(s: &str) -> Expr { From 50d4bd877ad6c8c52e00535f9cd2a7ca45b2a228 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 17:59:56 -0500 Subject: [PATCH 08/14] cleanup complete --- crates/sqlbuiltins/src/errors.rs | 6 + .../sqlbuiltins/src/functions/scalars/kdl.rs | 93 +++++---- .../sqlbuiltins/src/functions/scalars/mod.rs | 189 +++++++++--------- .../src/functions/scalars/postgres.rs | 36 ++-- 4 files changed, 178 insertions(+), 146 deletions(-) diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index 0a9fb5129..b1b6e2041 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -7,12 +7,18 @@ pub enum BuiltinError { #[error("parse error: {0}")] ParseError(String), + #[error("fundamental parsing error")] + FundamentalError, + #[error("missing value at index {0}")] MissingValueAtIndex(usize), #[error("invalid value at index {0}")] InvalidValueAtIndex(usize, String), + #[error("columnar values not support at index {0}")] + InvalidColumnarValue(usize), + #[error("value at index {0} was {1}, expected {2}")] IncorrectTypeAtIndex(usize, DataType, DataType), diff --git a/crates/sqlbuiltins/src/functions/scalars/kdl.rs b/crates/sqlbuiltins/src/functions/scalars/kdl.rs index 3c8112c89..63f15e0d2 100644 --- a/crates/sqlbuiltins/src/functions/scalars/kdl.rs +++ b/crates/sqlbuiltins/src/functions/scalars/kdl.rs @@ -32,27 +32,37 @@ impl BuiltinScalarUDF for KDLSelect { signature: ConstBuiltinFunction::signature(self).unwrap(), return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), fun: Arc::new(move |input| { - let (sdoc, filter) = kdl_parse_udf_args(input)?; - - let out: Vec<&KdlNode> = sdoc - .query_all(filter) - .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string())) - .map(|iter| iter.collect())?; - - let mut doc = sdoc.clone(); - let elems = doc.nodes_mut(); - elems.clear(); - for item in &out { - elems.push(item.to_owned().clone()) - } - - // TODO: consider if we should always return LargeUtf8? - // could end up with truncation (or an error) the document - // is too long and we write the data to a table that is - // established (and mostly) shorter values. - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - doc.to_string(), - )))) + let filter = get_nth_string_fn_arg(input, 1)?; + + get_nth_string_value( + input, + 0, + &|value: Option| -> Result { + let sdoc: kdl::KdlDocument = value + .ok_or(BuiltinError::MissingValueAtIndex(0))? + .parse() + .map_err(BuiltinError::KdlError)?; + + let out: Vec<&KdlNode> = sdoc + .query_all(compile_kdl_query(filter.clone())?) + .map_err(BuiltinError::KdlError) + .map(|iter| iter.collect())?; + + let mut doc = sdoc.clone(); + let elems = doc.nodes_mut(); + elems.clear(); + for item in &out { + elems.push(item.to_owned().clone()) + } + + // TODO: consider if we should always return LargeUtf8? + // could end up with truncation (or an error) the document + // is too long and we write the data to a table that is + // established (and mostly) shorter values. + Ok(ScalarValue::Utf8(Some(doc.to_string()))) + }, + ) + .map_err(DataFusionError::from) }), }; Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( @@ -92,13 +102,25 @@ impl BuiltinScalarUDF for KDLMatches { signature: ConstBuiltinFunction::signature(self).unwrap(), return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))), fun: Arc::new(move |input| { - let (doc, filter) = kdl_parse_udf_args(input)?; - - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - doc.query(filter) - .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string())) - .map(|val| val.is_some())?, - )))) + let filter = get_nth_string_fn_arg(input, 1)?; + + get_nth_string_value( + input, + 0, + &|value: Option| -> Result { + let doc: kdl::KdlDocument = value + .ok_or(BuiltinError::MissingValueAtIndex(0))? + .parse() + .map_err(BuiltinError::KdlError)?; + + Ok(ScalarValue::Boolean(Some( + doc.query(compile_kdl_query(filter.clone())?) + .map(|v| v.is_some()) + .map_err(BuiltinError::KdlError)?, + ))) + }, + ) + .map_err(DataFusionError::from) }), }; Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( @@ -108,21 +130,6 @@ impl BuiltinScalarUDF for KDLMatches { } } -fn kdl_parse_udf_args( - args: &[ColumnarValue], -) -> datafusion::error::Result<(KdlDocument, KdlQuery)> { - // parse the filter first, because it's probably shorter and - // erroring earlier would be preferable to parsing a large that we - // don't need/want. - let filter = compile_kdl_query(get_nth_string_value(args, 1)?)?; - - let doc: kdl::KdlDocument = get_nth_string_value(args, 0)? - .parse() - .map_err(BuiltinError::KdlError)?; - - Ok((doc, filter)) -} - #[memoize(Capacity: 256)] fn compile_kdl_query(query: String) -> Result { query.parse().map_err(BuiltinError::KdlError) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 9b20e4ed7..e31203411 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -4,7 +4,7 @@ pub mod postgres; use std::sync::Arc; -use datafusion::arrow::array::{make_array, Array, ArrayDataBuilder}; +use datafusion::arrow::array::Array; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::error::DataFusionError; use datafusion::logical_expr::BuiltinScalarFunction; @@ -58,23 +58,22 @@ fn get_nth_scalar_value( input: &[ColumnarValue], n: usize, op: &dyn Fn(Option) -> Result, - output_type: DataType, ) -> Result { match input.get(n) { Some(input) => match input { ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(Some(scalar.clone()))?)), ColumnarValue::Array(arr) => { - let mut builder = ArrayDataBuilder::new(output_type); + let mut values = Vec::with_capacity(arr.len()); for idx in 0..arr.len() { - builder.add_child_data( - op(Some(ScalarValue::try_from_array(arr, idx)?))? - .to_array() - .into_data(), - ); + let value = ScalarValue::try_from_array(arr, idx)?; + let value = op(Some(value))?; + values.push(value); } - Ok(ColumnarValue::Array(make_array(builder.build()?))) + Ok(ColumnarValue::Array(ScalarValue::iter_to_array( + values.into_iter(), + )?)) } }, None => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)))), @@ -99,105 +98,115 @@ fn safe_up_cast_integer_scalar( } } +#[allow(dead_code)] // will get removed before this hits mainline; stacked commit issue fn get_nth_scalar_as_u64( input: &[ColumnarValue], n: usize, op: &dyn Fn(Option) -> Result, - output_type: DataType, ) -> Result { - get_nth_scalar_value( - input, - n, - &|scalar| -> Result { - let scalar = match scalar.clone() { - Some(v) => v.to_owned(), - None => return Err(BuiltinError::MissingValueAtIndex(n)), - }; - - let value = match scalar { - ScalarValue::Int8(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int16(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int32(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int64(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value) - } - ScalarValue::UInt8(Some(value)) => Ok(value as u64), - ScalarValue::UInt16(Some(value)) => Ok(value as u64), - ScalarValue::UInt32(Some(value)) => Ok(value as u64), - ScalarValue::Float64(Some(value)) => { - if value.trunc() != value { - return Err(BuiltinError::InvalidValueAtIndex( - n, - format!("expected whole value for float {}", value).to_string(), - )); - } - Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( - n, - scalar.data_type(), - DataType::UInt64, - ))? as u64) - } - ScalarValue::Float32(Some(value)) => { - if value.trunc() != value { - return Err(BuiltinError::InvalidValueAtIndex( - n, - format!("expected whole value for float {}", value).to_string(), - )); - } - Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( + get_nth_scalar_value(input, n, &|scalar| -> Result { + let scalar = match scalar.clone() { + Some(v) => v.to_owned(), + None => return Err(BuiltinError::MissingValueAtIndex(n)), + }; + + let value = match scalar { + ScalarValue::Int8(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int16(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int32(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) + } + ScalarValue::Int64(Some(value)) => { + safe_up_cast_integer_scalar(scalar.data_type(), n, value) + } + ScalarValue::UInt8(Some(value)) => Ok(value as u64), + ScalarValue::UInt16(Some(value)) => Ok(value as u64), + ScalarValue::UInt32(Some(value)) => Ok(value as u64), + ScalarValue::Float64(Some(value)) => { + if value.trunc() != value { + return Err(BuiltinError::InvalidValueAtIndex( n, - scalar.data_type(), - DataType::UInt64, - ))? as u64) + format!("expected whole value for float {}", value).to_string(), + )); } - ScalarValue::UInt64(Some(value)) => Ok(value), - _ => { - return Err(BuiltinError::IncorrectTypeAtIndex( + Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::UInt64, + ))? as u64) + } + ScalarValue::Float32(Some(value)) => { + if value.trunc() != value { + return Err(BuiltinError::InvalidValueAtIndex( n, - scalar.data_type(), - DataType::UInt64, - )) + format!("expected whole value for float {}", value).to_string(), + )); } - }?; + Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::UInt64, + ))? as u64) + } + ScalarValue::UInt64(Some(value)) => Ok(value), + _ => { + return Err(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::UInt64, + )) + } + }?; + + op(Some(value)) + }) +} - op(Some(value)) +// get_nth_string_fn_arg extracts a string value (or tries to) from a +// function argument; columns are always an error. +fn get_nth_string_fn_arg(input: &[ColumnarValue], idx: usize) -> Result { + match input.get(idx) { + Some(input) => match input { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => Ok(v.to_owned()), + ColumnarValue::Array(_) => Err(BuiltinError::InvalidColumnarValue(idx)), + _ => Err(BuiltinError::IncorrectTypeAtIndex( + idx, + input.data_type(), + DataType::Utf8, + )), }, - output_type, - ) + None => Err(BuiltinError::MissingValueAtIndex(idx)), + } } +// get_nth_string_value processes a function argument that is expected +// to be a string, as a helper for a common case around +// get_nth_scalar_value. fn get_nth_string_value( input: &[ColumnarValue], n: usize, op: &dyn Fn(Option) -> Result, - output_type: DataType, ) -> Result { - get_nth_scalar_value( - input, - n, - &|scalar| -> Result { - let scalar = match scalar.clone() { - Some(v) => v.to_owned(), - None => return Err(BuiltinError::MissingValueAtIndex(n)), - }; - - match scalar { - ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v) => op(v), - _ => Err(BuiltinError::IncorrectTypeAtIndex( - n, - scalar.data_type(), - DataType::Utf8, - )), - } - }, - output_type, - ) + get_nth_scalar_value(input, n, &|scalar| -> Result { + let scalar = match scalar.clone() { + Some(v) => v.to_owned(), + None => return Err(BuiltinError::MissingValueAtIndex(n)), + }; + + match scalar { + ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v) => op(v), + _ => Err(BuiltinError::IncorrectTypeAtIndex( + n, + scalar.data_type(), + DataType::Utf8, + )), + } + }) } fn session_var(s: &str) -> Expr { diff --git a/crates/sqlbuiltins/src/functions/scalars/postgres.rs b/crates/sqlbuiltins/src/functions/scalars/postgres.rs index b8fcd91bb..dc641ea8d 100644 --- a/crates/sqlbuiltins/src/functions/scalars/postgres.rs +++ b/crates/sqlbuiltins/src/functions/scalars/postgres.rs @@ -62,12 +62,15 @@ impl BuiltinScalarUDF for PgTableIsVisible { signature: ConstBuiltinFunction::signature(self).unwrap(), return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))), fun: Arc::new(move |input| { - let is_visible = match get_nth_scalar_value(input, 0) { - Some(ScalarValue::Int64(Some(_))) => Some(true), - _ => None, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(is_visible))) + Ok(get_nth_scalar_value(input, 0, &|value| -> Result< + ScalarValue, + BuiltinError, + > { + match value { + Some(ScalarValue::Int64(Some(_))) => Ok(ScalarValue::Boolean(Some(true))), + _ => Ok(ScalarValue::Boolean(None)), + } + })?) }), }; @@ -103,13 +106,20 @@ impl BuiltinScalarUDF for PgEncodingToChar { signature: ConstBuiltinFunction::signature(self).unwrap(), return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), fun: Arc::new(move |input| { - let enc = match get_nth_scalar_value(input, 0) { - Some(ScalarValue::Int64(Some(6))) => Some("UTF8".to_string()), - Some(ScalarValue::Int64(Some(_))) => Some("".to_string()), - _ => None, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(enc))) + Ok(get_nth_scalar_value(input, 0, &|value| -> Result< + ScalarValue, + BuiltinError, + > { + match value { + Some(ScalarValue::Int64(Some(6))) => { + Ok(ScalarValue::Utf8(Some("UTF8".to_string()))) + } + Some(ScalarValue::Int64(Some(_))) => { + Ok(ScalarValue::Utf8(Some("".to_string()))) + } + _ => Ok(ScalarValue::Utf8(None)), + } + })?) }), }; Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( From 3ce73d8853db2289637c7c07b23b91f7672fafbb Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 20:37:09 -0500 Subject: [PATCH 09/14] backport --- crates/sqlbuiltins/src/errors.rs | 11 +- .../sqlbuiltins/src/functions/scalars/kdl.rs | 16 +- .../sqlbuiltins/src/functions/scalars/mod.rs | 152 +++++++----------- .../src/functions/scalars/postgres.rs | 8 +- 4 files changed, 75 insertions(+), 112 deletions(-) diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index b1b6e2041..6f8b7a02e 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -13,14 +13,17 @@ pub enum BuiltinError { #[error("missing value at index {0}")] MissingValueAtIndex(usize), - #[error("invalid value at index {0}")] - InvalidValueAtIndex(usize, String), + #[error("expected value missing")] + MissingValue, + + #[error("invalid value: {0}")] + InvalidValue(String), #[error("columnar values not support at index {0}")] InvalidColumnarValue(usize), - #[error("value at index {0} was {1}, expected {2}")] - IncorrectTypeAtIndex(usize, DataType, DataType), + #[error("value was type {0}, expected {1}")] + IncorrectType(DataType, DataType), #[error(transparent)] KdlError(#[from] kdl::KdlError), diff --git a/crates/sqlbuiltins/src/functions/scalars/kdl.rs b/crates/sqlbuiltins/src/functions/scalars/kdl.rs index 63f15e0d2..9e4483f78 100644 --- a/crates/sqlbuiltins/src/functions/scalars/kdl.rs +++ b/crates/sqlbuiltins/src/functions/scalars/kdl.rs @@ -37,11 +37,9 @@ impl BuiltinScalarUDF for KDLSelect { get_nth_string_value( input, 0, - &|value: Option| -> Result { - let sdoc: kdl::KdlDocument = value - .ok_or(BuiltinError::MissingValueAtIndex(0))? - .parse() - .map_err(BuiltinError::KdlError)?; + &|value: String| -> Result { + let sdoc: kdl::KdlDocument = + value.parse().map_err(BuiltinError::KdlError)?; let out: Vec<&KdlNode> = sdoc .query_all(compile_kdl_query(filter.clone())?) @@ -107,11 +105,9 @@ impl BuiltinScalarUDF for KDLMatches { get_nth_string_value( input, 0, - &|value: Option| -> Result { - let doc: kdl::KdlDocument = value - .ok_or(BuiltinError::MissingValueAtIndex(0))? - .parse() - .map_err(BuiltinError::KdlError)?; + &|value: String| -> Result { + let doc: kdl::KdlDocument = + value.parse().map_err(BuiltinError::KdlError)?; Ok(ScalarValue::Boolean(Some( doc.query(compile_kdl_query(filter.clone())?) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index e31203411..669f1822b 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -57,18 +57,16 @@ impl BuiltinScalarUDF for Version { fn get_nth_scalar_value( input: &[ColumnarValue], n: usize, - op: &dyn Fn(Option) -> Result, + op: &dyn Fn(ScalarValue) -> Result, ) -> Result { match input.get(n) { Some(input) => match input { - ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(Some(scalar.clone()))?)), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar.clone())?)), ColumnarValue::Array(arr) => { let mut values = Vec::with_capacity(arr.len()); for idx in 0..arr.len() { - let value = ScalarValue::try_from_array(arr, idx)?; - let value = op(Some(value))?; - values.push(value); + values.push(op(ScalarValue::try_from_array(arr, idx)?)?); } Ok(ColumnarValue::Array(ScalarValue::iter_to_array( @@ -80,90 +78,65 @@ fn get_nth_scalar_value( } } -fn safe_up_cast_integer_scalar( - dt: DataType, - idx: usize, - value: i64, -) -> Result { +fn try_from_u64_scalar(scalar: ScalarValue) -> Result { + match scalar { + ScalarValue::Int8(Some(value)) => safe_up_cast_integer_scalar(value as i64), + ScalarValue::Int16(Some(value)) => safe_up_cast_integer_scalar(value as i64), + ScalarValue::Int32(Some(value)) => safe_up_cast_integer_scalar(value as i64), + ScalarValue::Int64(Some(value)) => safe_up_cast_integer_scalar(value), + ScalarValue::UInt8(Some(value)) => Ok(value as u64), + ScalarValue::UInt16(Some(value)) => Ok(value as u64), + ScalarValue::UInt32(Some(value)) => Ok(value as u64), + ScalarValue::Float64(Some(value)) => { + if value.trunc() != value { + return Err(BuiltinError::ParseError( + format!("expected whole value for float {}", value).to_string(), + )); + } + Ok(value.to_i64().ok_or(BuiltinError::IncorrectType( + scalar.data_type(), + DataType::UInt64, + ))? as u64) + } + ScalarValue::Float32(Some(value)) => { + if value.trunc() != value { + return Err(BuiltinError::InvalidValue( + format!("expected whole value for float {}", value).to_string(), + )); + } + Ok(value.to_i64().ok_or(BuiltinError::IncorrectType( + scalar.data_type(), + DataType::UInt64, + ))? as u64) + } + ScalarValue::UInt64(Some(value)) => Ok(value), + _ => { + return Err(BuiltinError::IncorrectType( + scalar.data_type(), + DataType::UInt64, + )) + } + } +} + +fn safe_up_cast_integer_scalar(value: i64) -> Result { if value < 0 { - Err(datafusion::error::DataFusionError::Execution( - format!( - "expected {} value at {} to be greater than zero or unsigned", - dt, idx, - ) - .to_string(), + Err(BuiltinError::ParseError( + format!("{} cannot be a uint64", value).to_string(), )) } else { Ok(value as u64) } } -#[allow(dead_code)] // will get removed before this hits mainline; stacked commit issue -fn get_nth_scalar_as_u64( - input: &[ColumnarValue], - n: usize, - op: &dyn Fn(Option) -> Result, -) -> Result { - get_nth_scalar_value(input, n, &|scalar| -> Result { - let scalar = match scalar.clone() { - Some(v) => v.to_owned(), - None => return Err(BuiltinError::MissingValueAtIndex(n)), - }; - - let value = match scalar { - ScalarValue::Int8(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int16(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int32(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value as i64) - } - ScalarValue::Int64(Some(value)) => { - safe_up_cast_integer_scalar(scalar.data_type(), n, value) - } - ScalarValue::UInt8(Some(value)) => Ok(value as u64), - ScalarValue::UInt16(Some(value)) => Ok(value as u64), - ScalarValue::UInt32(Some(value)) => Ok(value as u64), - ScalarValue::Float64(Some(value)) => { - if value.trunc() != value { - return Err(BuiltinError::InvalidValueAtIndex( - n, - format!("expected whole value for float {}", value).to_string(), - )); - } - Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( - n, - scalar.data_type(), - DataType::UInt64, - ))? as u64) - } - ScalarValue::Float32(Some(value)) => { - if value.trunc() != value { - return Err(BuiltinError::InvalidValueAtIndex( - n, - format!("expected whole value for float {}", value).to_string(), - )); - } - Ok(value.to_i64().ok_or(BuiltinError::IncorrectTypeAtIndex( - n, - scalar.data_type(), - DataType::UInt64, - ))? as u64) - } - ScalarValue::UInt64(Some(value)) => Ok(value), - _ => { - return Err(BuiltinError::IncorrectTypeAtIndex( - n, - scalar.data_type(), - DataType::UInt64, - )) - } - }?; - - op(Some(value)) - }) +// get_nth_64_fn_arg extracts a string value (or tries to) from a +// function argument; columns are always an error. +fn get_nth_u64_fn_arg(input: &[ColumnarValue], idx: usize) -> Result { + match input.get(idx) { + Some(ColumnarValue::Scalar(value)) => try_from_u64_scalar(value.to_owned()), + Some(ColumnarValue::Array(_)) => Err(BuiltinError::InvalidColumnarValue(idx)), + None => Err(BuiltinError::MissingValueAtIndex(idx)), + } } // get_nth_string_fn_arg extracts a string value (or tries to) from a @@ -174,8 +147,7 @@ fn get_nth_string_fn_arg(input: &[ColumnarValue], idx: usize) -> Result Ok(v.to_owned()), ColumnarValue::Array(_) => Err(BuiltinError::InvalidColumnarValue(idx)), - _ => Err(BuiltinError::IncorrectTypeAtIndex( - idx, + _ => Err(BuiltinError::IncorrectType( input.data_type(), DataType::Utf8, )), @@ -190,18 +162,12 @@ fn get_nth_string_fn_arg(input: &[ColumnarValue], idx: usize) -> Result) -> Result, + op: &dyn Fn(String) -> Result, ) -> Result { get_nth_scalar_value(input, n, &|scalar| -> Result { - let scalar = match scalar.clone() { - Some(v) => v.to_owned(), - None => return Err(BuiltinError::MissingValueAtIndex(n)), - }; - match scalar { - ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v) => op(v), - _ => Err(BuiltinError::IncorrectTypeAtIndex( - n, + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => op(v), + _ => Err(BuiltinError::IncorrectType( scalar.data_type(), DataType::Utf8, )), diff --git a/crates/sqlbuiltins/src/functions/scalars/postgres.rs b/crates/sqlbuiltins/src/functions/scalars/postgres.rs index dc641ea8d..b80a30aee 100644 --- a/crates/sqlbuiltins/src/functions/scalars/postgres.rs +++ b/crates/sqlbuiltins/src/functions/scalars/postgres.rs @@ -67,7 +67,7 @@ impl BuiltinScalarUDF for PgTableIsVisible { BuiltinError, > { match value { - Some(ScalarValue::Int64(Some(_))) => Ok(ScalarValue::Boolean(Some(true))), + ScalarValue::Int64(Some(_)) => Ok(ScalarValue::Boolean(Some(true))), _ => Ok(ScalarValue::Boolean(None)), } })?) @@ -111,12 +111,10 @@ impl BuiltinScalarUDF for PgEncodingToChar { BuiltinError, > { match value { - Some(ScalarValue::Int64(Some(6))) => { + ScalarValue::Int64(Some(6)) => { Ok(ScalarValue::Utf8(Some("UTF8".to_string()))) } - Some(ScalarValue::Int64(Some(_))) => { - Ok(ScalarValue::Utf8(Some("".to_string()))) - } + ScalarValue::Int64(Some(_)) => Ok(ScalarValue::Utf8(Some("".to_string()))), _ => Ok(ScalarValue::Utf8(None)), } })?) From 1f59ad03bf518204b77492f2d138cfc16f9583d0 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 21:24:46 -0500 Subject: [PATCH 10/14] fix lint --- crates/sqlbuiltins/src/functions/scalars/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 669f1822b..172f6c9aa 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -110,15 +110,14 @@ fn try_from_u64_scalar(scalar: ScalarValue) -> Result { ))? as u64) } ScalarValue::UInt64(Some(value)) => Ok(value), - _ => { - return Err(BuiltinError::IncorrectType( - scalar.data_type(), - DataType::UInt64, - )) - } + _ => Err(BuiltinError::IncorrectType( + scalar.data_type(), + DataType::UInt64, + )), } } +#[allow(dead_code)] // just for merging order fn safe_up_cast_integer_scalar(value: i64) -> Result { if value < 0 { Err(BuiltinError::ParseError( @@ -131,6 +130,7 @@ fn safe_up_cast_integer_scalar(value: i64) -> Result { // get_nth_64_fn_arg extracts a string value (or tries to) from a // function argument; columns are always an error. +#[allow(dead_code)] // just for merging order fn get_nth_u64_fn_arg(input: &[ColumnarValue], idx: usize) -> Result { match input.get(idx) { Some(ColumnarValue::Scalar(value)) => try_from_u64_scalar(value.to_owned()), From 61834a416b6bbb5261b83d6b6bc11dab671d7328 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Sun, 24 Dec 2023 21:33:56 -0500 Subject: [PATCH 11/14] cleanup --- .../src/functions/scalars/hashing.rs | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/hashing.rs b/crates/sqlbuiltins/src/functions/scalars/hashing.rs index 1eb9c6d6c..68e31d9b8 100644 --- a/crates/sqlbuiltins/src/functions/scalars/hashing.rs +++ b/crates/sqlbuiltins/src/functions/scalars/hashing.rs @@ -33,17 +33,9 @@ impl BuiltinScalarUDF for SipHash { ScalarValue, BuiltinError, > { - if let Some(value) = value { - let mut hasher = SipHasher24::new(); - - value.hash(&mut hasher); - - Ok(ScalarValue::UInt64(Some(hasher.finish()))) - } else { - Err(BuiltinError::ParseError( - "must have exactly one value to hash".to_string(), - )) - } + let mut hasher = SipHasher24::new(); + value.hash(&mut hasher); + Ok(ScalarValue::UInt64(Some(hasher.finish()))) })?) }), }; @@ -82,17 +74,9 @@ impl BuiltinScalarUDF for FnvHash { ScalarValue, BuiltinError, > { - if let Some(value) = value { - let mut hasher = FnvHasher::default(); - - value.hash(&mut hasher); - - Ok(ScalarValue::UInt64(Some(hasher.finish()))) - } else { - Err(BuiltinError::ParseError( - "must have exactly one value to hash".to_string(), - )) - } + let mut hasher = FnvHasher::default(); + value.hash(&mut hasher); + Ok(ScalarValue::UInt64(Some(hasher.finish()))) })?) }), }; From 6a0eea3c873f026eb856b30664dde293c949ab71 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Mon, 25 Dec 2023 02:56:24 -0500 Subject: [PATCH 12/14] reject missing expected value --- crates/sqlbuiltins/src/functions/scalars/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 172f6c9aa..838f30f44 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -74,7 +74,7 @@ fn get_nth_scalar_value( )?)) } }, - None => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)))), + None => Err(BuiltinError::MissingValueAtIndex(idx)), } } From e373eb6b6fb53bcbe4b3e85d9102d32da076d388 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Mon, 25 Dec 2023 09:22:06 -0500 Subject: [PATCH 13/14] fix --- crates/sqlbuiltins/src/functions/scalars/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index 838f30f44..fef3df516 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -74,7 +74,7 @@ fn get_nth_scalar_value( )?)) } }, - None => Err(BuiltinError::MissingValueAtIndex(idx)), + None => Err(BuiltinError::MissingValueAtIndex(n)), } } From 6835f6566d40077580eea8175115ba6abc6f94c4 Mon Sep 17 00:00:00 2001 From: tycho garen Date: Wed, 27 Dec 2023 11:39:36 -0500 Subject: [PATCH 14/14] add test case --- tests/tests/test_functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/tests/test_functions.py diff --git a/tests/tests/test_functions.py b/tests/tests/test_functions.py new file mode 100644 index 000000000..1bb2fba8b --- /dev/null +++ b/tests/tests/test_functions.py @@ -0,0 +1,19 @@ +import psycopg2 + +from fixtures.glaredb import glaredb_connection, debug_path + + +def test_scalar_parsing( + glaredb_connection: psycopg2.extensions.connection, +): + for operation in ["create table t (x int);", "insert into t values (1), (2), (3);"]: + with glaredb_connection.cursor() as cur: + cur.execute(operation) + + with glaredb_connection.cursor() as cur: + cur.execute("select siphash(x) as actual, siphash(arrow_cast(1, 'Int32')) as one from t") + result = cur.fetchall() + + assert result[0][0] == result[0][1] + assert result[1][0] != result[1][1] + assert result[2][0] != result[2][1]