diff --git a/Cargo.lock b/Cargo.lock index 31c0658f3..1c45e791f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8232,6 +8232,7 @@ dependencies = [ "arrow-cast", "async-openai", "async-trait", + "bson", "bytes", "catalog", "datafusion", diff --git a/crates/sqlbuiltins/Cargo.toml b/crates/sqlbuiltins/Cargo.toml index 121e489b1..17af60696 100644 --- a/crates/sqlbuiltins/Cargo.toml +++ b/crates/sqlbuiltins/Cargo.toml @@ -46,3 +46,4 @@ arrow-cast = { version = "50.0.0" } # MUST synchronize sync with the datafusion: lance-linalg = { git = "https://github.com/GlareDB/lance", branch = "df36" } # omits duckdb submodule jaq-interpret = "1.5.0" jaq-parse = "1.0.2" +bson = "2.11.0" diff --git a/crates/sqlbuiltins/src/errors.rs b/crates/sqlbuiltins/src/errors.rs index 9b5d400f8..a2b94220f 100644 --- a/crates/sqlbuiltins/src/errors.rs +++ b/crates/sqlbuiltins/src/errors.rs @@ -41,6 +41,9 @@ pub enum BuiltinError { #[error("serde_json: {0}")] SerdeJsonError(String), + #[error(transparent)] + BsonSer(#[from] bson::ser::Error), + #[error("jaq: {0}")] JaqError(String), } diff --git a/crates/sqlbuiltins/src/functions/mod.rs b/crates/sqlbuiltins/src/functions/mod.rs index 9a94e9bea..0314e0ee9 100644 --- a/crates/sqlbuiltins/src/functions/mod.rs +++ b/crates/sqlbuiltins/src/functions/mod.rs @@ -16,6 +16,7 @@ use datafusion::logical_expr::{ }; use once_cell::sync::Lazy; use protogen::metastore::types::catalog::FunctionType; +use scalars::bson2json::Bson2Json; use scalars::df_scalars::ArrowCastFunction; use scalars::hashing::{FnvHash, PartitionResults, SipHash}; use scalars::jaq::{JAQMatches, JAQSelect}; @@ -41,6 +42,7 @@ use scalars::{ConnectionId, Version}; use table::{BuiltinTableFuncs, TableFunc}; use self::alias_map::AliasMap; +use crate::functions::scalars::bson2json::Json2Bson; use crate::functions::scalars::df_scalars::{Decode, Encode, IsNan, NullIf}; use crate::functions::scalars::openai::OpenAIEmbed; use crate::functions::scalars::similarity::CosineSimilarity; @@ -246,6 +248,9 @@ 
impl FunctionRegistry { // JAQ functions Arc::new(JAQMatches::new()), Arc::new(JAQSelect::new()), + // Converters + Arc::new(Bson2Json), + Arc::new(Json2Bson), // Hashing/Partitioning Arc::new(SipHash), Arc::new(FnvHash),
diff --git a/crates/sqlbuiltins/src/functions/scalars/bson2json.rs b/crates/sqlbuiltins/src/functions/scalars/bson2json.rs
new file mode 100644
index 000000000..15f3fdf49
--- /dev/null
+++ b/crates/sqlbuiltins/src/functions/scalars/bson2json.rs
@@ -0,0 +1,172 @@
+use std::sync::Arc;
+
+use bson::Bson;
+use catalog::session_catalog::SessionCatalog;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::error::{DataFusionError, Result as DataFusionResult};
+use datafusion::logical_expr::expr::ScalarFunction;
+use datafusion::logical_expr::{
+    ColumnarValue,
+    ReturnTypeFunction,
+    ScalarFunctionImplementation,
+    ScalarUDF,
+    Signature,
+    TypeSignature,
+    Volatility,
+};
+use datafusion::prelude::Expr;
+use datafusion::scalar::ScalarValue;
+use protogen::metastore::types::catalog::FunctionType;
+
+use super::apply_op_to_col_array;
+use crate::errors::BuiltinError;
+use crate::functions::{BuiltinScalarUDF, ConstBuiltinFunction};
+
+/// Scalar function `bson2json`: renders a BSON-encoded binary value as a
+/// relaxed extended JSON string.
+pub struct Bson2Json;
+
+impl ConstBuiltinFunction for Bson2Json {
+    const NAME: &'static str = "bson2json";
+    const DESCRIPTION: &'static str = "Converts a bson value to a (relaxed extended) json string";
+    const EXAMPLE: &'static str = "bson2json(<value>)";
+    const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;
+
+    /// Accepts no argument, or exactly one `Binary`/`LargeBinary` argument.
+    fn signature(&self) -> Option<Signature> {
+        Some(Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![]),
+                TypeSignature::Exact(vec![DataType::Binary]),
+                TypeSignature::Exact(vec![DataType::LargeBinary]),
+            ],
+            Volatility::Immutable,
+        ))
+    }
+}
+
+impl Bson2Json {
+    /// Decodes a single BSON value and renders it as a JSON string.
+    ///
+    /// A null binary input yields a null string; any non-binary scalar is
+    /// rejected with [`BuiltinError::IncorrectType`].
+    fn convert(scalar: &ScalarValue) -> Result<ScalarValue, BuiltinError> {
+        match scalar {
+            ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
+                let value = bson::de::from_slice::<Bson>(v)
+                    .map_err(|e| DataFusionError::External(Box::new(e)))?;
+                Ok(ScalarValue::new_utf8(value.into_relaxed_extjson().to_string()))
+            }
+            ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => {
+                Ok(ScalarValue::Utf8(None))
+            }
+            other => Err(BuiltinError::IncorrectType(
+                other.data_type(),
+                DataType::Binary,
+            )),
+        }
+    }
+}
+
+impl BuiltinScalarUDF for Bson2Json {
+    /// Builds the `bson2json` UDF call expression over `args`.
+    fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
+        let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8)));
+        let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
+            Ok(match input {
+                // Called without an argument: the empty JSON document.
+                [] => ColumnarValue::Scalar(ScalarValue::new_utf8("{}")),
+                [ColumnarValue::Scalar(scalar)] => ColumnarValue::Scalar(Self::convert(scalar)?),
+                [ColumnarValue::Array(array)] => {
+                    ColumnarValue::Array(apply_op_to_col_array(array, &Self::convert)?)
+                }
+                _ => unreachable!("bson2json expects at most one argument"),
+            })
+        });
+        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+            Arc::new(ScalarUDF::new(
+                Self::NAME,
+                &ConstBuiltinFunction::signature(self).unwrap(),
+                &return_type_fn,
+                &scalar_fn_impl,
+            )),
+            args,
+        )))
+    }
+}
+
+/// Scalar function `json2bson`: parses a JSON string and encodes it as BSON
+/// bytes.
+pub struct Json2Bson;
+
+impl ConstBuiltinFunction for Json2Bson {
+    const NAME: &'static str = "json2bson";
+    const DESCRIPTION: &'static str = "Converts a json string value to Bson";
+    const EXAMPLE: &'static str = "json2bson(<value>)";
+    const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;
+
+    /// Accepts no argument, or exactly one `Utf8`/`LargeUtf8` argument.
+    fn signature(&self) -> Option<Signature> {
+        Some(Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![]),
+                TypeSignature::Exact(vec![DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::LargeUtf8]),
+            ],
+            Volatility::Immutable,
+        ))
+    }
+}
+
+impl Json2Bson {
+    /// Parses a single JSON string and serializes it to BSON bytes.
+    ///
+    /// A null string input yields a null binary value; any non-string scalar
+    /// is rejected with [`BuiltinError::IncorrectType`].
+    fn convert(scalar: &ScalarValue) -> Result<ScalarValue, BuiltinError> {
+        match scalar {
+            ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => {
+                let json = serde_json::from_str::<serde_json::Value>(v)?;
+                Ok(ScalarValue::Binary(Some(bson::ser::to_vec(&json)?)))
+            }
+            ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => Ok(ScalarValue::Binary(None)),
+            other => Err(BuiltinError::IncorrectType(
+                other.data_type(),
+                DataType::Utf8,
+            )),
+        }
+    }
+}
+
+impl BuiltinScalarUDF for Json2Bson {
+    /// Builds the `json2bson` UDF call expression over `args`.
+    fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
+        let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Binary)));
+        let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
+            Ok(match input {
+                [] => {
+                    // Zero bytes is not a BSON document, so encode an empty
+                    // document instead; this is the counterpart of the `{}`
+                    // that `bson2json()` returns for no argument.
+                    let empty = bson::ser::to_vec(&bson::Document::new())
+                        .map_err(BuiltinError::from)?;
+                    ColumnarValue::Scalar(ScalarValue::Binary(Some(empty)))
+                }
+                [ColumnarValue::Scalar(scalar)] => ColumnarValue::Scalar(Self::convert(scalar)?),
+                [ColumnarValue::Array(array)] => {
+                    ColumnarValue::Array(apply_op_to_col_array(array, &Self::convert)?)
+                }
+                _ => unreachable!("json2bson expects at most one argument"),
+            })
+        });
+        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+            Arc::new(ScalarUDF::new(
+                Self::NAME,
+                &ConstBuiltinFunction::signature(self).unwrap(),
+                &return_type_fn,
+                &scalar_fn_impl,
+            )),
+            args,
+        )))
+    }
+}
diff --git a/crates/sqlbuiltins/src/functions/scalars/jaq.rs b/crates/sqlbuiltins/src/functions/scalars/jaq.rs index 1708c5935..1f862eb40 100644 --- a/crates/sqlbuiltins/src/functions/scalars/jaq.rs +++ b/crates/sqlbuiltins/src/functions/scalars/jaq.rs @@ -89,8 +89,8 @@ impl ScalarUDFImpl for JAQSelect { get_nth_string_value( input, 0, - &|value: String| -> Result<ScalarValue, BuiltinError> { - let val: Value = serde_json::from_str(&value)?; + &|value: &String| -> Result<ScalarValue, BuiltinError> { + let val: Value = serde_json::from_str(value)?; let inputs = RcIter::new(core::iter::empty()); let output = filter @@ -197,8 +197,8 @@ impl ScalarUDFImpl for JAQMatches { get_nth_string_value( input, 0, - &|value: String| -> Result<ScalarValue, BuiltinError> { - let val: Value = serde_json::from_str(&value)?; + &|value: &String| 
-> Result<ScalarValue, BuiltinError> { + let val: Value = serde_json::from_str(value)?; let input = RcIter::new(core::iter::empty()); let output = filter.run((Ctx::new([], &input), Val::from(val))); diff --git a/crates/sqlbuiltins/src/functions/scalars/kdl.rs b/crates/sqlbuiltins/src/functions/scalars/kdl.rs index 2a1cca1ed..38ab0ed6a 100644 --- a/crates/sqlbuiltins/src/functions/scalars/kdl.rs +++ b/crates/sqlbuiltins/src/functions/scalars/kdl.rs @@ -87,7 +87,7 @@ impl ScalarUDFImpl for KDLSelect { get_nth_string_value( input, 0, - &|value: String| -> Result<ScalarValue, BuiltinError> { + &|value: &String| -> Result<ScalarValue, BuiltinError> { let sdoc: kdl::KdlDocument = value.parse().map_err(BuiltinError::KdlError)?; let out: Vec<&KdlNode> = sdoc @@ -200,7 +200,7 @@ impl ScalarUDFImpl for KDLMatches { get_nth_string_value( input, 0, - &|value: String| -> Result<ScalarValue, BuiltinError> { + &|value: &String| -> Result<ScalarValue, BuiltinError> { let doc: kdl::KdlDocument = value.parse().map_err(BuiltinError::KdlError)?; Ok(ScalarValue::Boolean(Some( diff --git a/crates/sqlbuiltins/src/functions/scalars/mod.rs b/crates/sqlbuiltins/src/functions/scalars/mod.rs index ded5fbee5..e6686e815 100644 --- a/crates/sqlbuiltins/src/functions/scalars/mod.rs +++ b/crates/sqlbuiltins/src/functions/scalars/mod.rs @@ -1,3 +1,4 @@ +pub mod bson2json; pub mod df_scalars; pub mod hashing; pub mod jaq; @@ -19,7 +20,6 @@ use crate::document; use crate::errors::BuiltinError; use crate::functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction}; - pub struct ConnectionId; impl ConstBuiltinFunction for ConnectionId { @@ -67,11 +67,11 @@ impl BuiltinScalarUDF for Version { fn get_nth_scalar_value( input: &[ColumnarValue], n: usize, - op: &dyn Fn(ScalarValue) -> Result<ScalarValue, BuiltinError>, + op: &dyn Fn(&ScalarValue) -> Result<ScalarValue, BuiltinError>, ) -> Result<ColumnarValue, BuiltinError> { match input.get(n) { Some(input) => match input { - 
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar.clone())?)), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar)?)), ColumnarValue::Array(arr) => Ok(ColumnarValue::Array(apply_op_to_col_array(arr, op)?)), }, None => Err(BuiltinError::MissingValueAtIndex(n)), @@ -80,7 +80,7 @@ fn get_nth_scalar_value( fn apply_op_to_col_array( arr: &dyn Array, - op: &dyn Fn(ScalarValue) -> Result<ScalarValue, BuiltinError>, + op: &dyn Fn(&ScalarValue) -> Result<ScalarValue, BuiltinError>, ) -> Result<Arc<dyn Array>, BuiltinError> { let mut check_err: Result<(), BuiltinError> = Ok(()); @@ -104,7 +104,7 @@ fn apply_op_to_col_array( let iter = (0..arr.len()).filter_map(|idx| { let scalar_res = ScalarValue::try_from_array(arr, idx).map_err(BuiltinError::from); let scalar = filter_fn(&mut check_err, scalar_res)?; - filter_fn(&mut check_err, op(scalar)) + filter_fn(&mut check_err, op(&scalar)) }); // NB: ScalarValue::iter_to_array accepts an iterator over @@ -207,7 +207,7 @@ fn get_nth_string_fn_arg(input: &[ColumnarValue], idx: usize) -> Result<String, fn get_nth_string_value( input: &[ColumnarValue], n: usize, - op: &dyn Fn(String) -> Result<ScalarValue, BuiltinError>, + op: &dyn Fn(&String) -> Result<ScalarValue, BuiltinError>, ) -> Result<ColumnarValue, BuiltinError> { get_nth_scalar_value(input, n, &|scalar| -> Result<ScalarValue, BuiltinError> { match scalar { diff --git a/testdata/sqllogictests/functions/bson.slt b/testdata/sqllogictests/functions/bson.slt new file mode 100644 index 000000000..467bc30ee --- /dev/null +++ b/testdata/sqllogictests/functions/bson.slt @@ -0,0 +1,43 @@ +statement ok +CREATE TEMP TABLE bson_conversions (id int, json text, bson bytea); + +statement ok +INSERT INTO bson_conversions +VALUES + (0, '{"a":1}', json2bson('{"a":1}')), + (1, '{"b":2}', json2bson('{"b":2}')); + +query +SELECT jaq_select(json, '.a') +FROM bson_conversions +WHERE id = 0; +---- +1 + +query +SELECT jaq_select(bson2json(bson), '.a') +FROM 
bson_conversions +WHERE id = 0; +---- +1 + +query +SELECT jaq_select(bson2json(bson), '.a') +FROM bson_conversions +WHERE id = 1; +---- +NULL + +query +SELECT bson2json(bson) = json +FROM bson_conversions +WHERE id = 0 +---- +t + +query +SELECT json2bson(json) = bson +FROM bson_conversions +WHERE id = 0 +---- +t