From 0a55707833c668087efdbcfec2be392b2929be87 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Thu, 5 Sep 2024 10:09:01 -0400 Subject: [PATCH 1/2] Start setting up wasm --- crates/wasm-functions/Cargo.toml | 14 ++ crates/wasm-functions/src/lib.rs | 84 +++++++ crates/wasm-udfs/Cargo.toml | 10 + crates/wasm-udfs/src/lib.rs | 67 +++++ crates/wasmedge-factory/Cargo.toml | 15 ++ crates/wasmedge-factory/src/lib.rs | 383 +++++++++++++++++++++++++++++ crates/wasmedge-factory/src/udf.rs | 122 +++++++++ 7 files changed, 695 insertions(+) create mode 100644 crates/wasm-functions/Cargo.toml create mode 100644 crates/wasm-functions/src/lib.rs create mode 100644 crates/wasm-udfs/Cargo.toml create mode 100644 crates/wasm-udfs/src/lib.rs create mode 100644 crates/wasmedge-factory/Cargo.toml create mode 100644 crates/wasmedge-factory/src/lib.rs create mode 100644 crates/wasmedge-factory/src/udf.rs diff --git a/crates/wasm-functions/Cargo.toml b/crates/wasm-functions/Cargo.toml new file mode 100644 index 0000000..713a1f9 --- /dev/null +++ b/crates/wasm-functions/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "wasm-functions" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +arrow = "51.0.0" +wasm-udfs = { version = "0.1.0", path = "../wasm-udfs" } + +[dev-dependencies] +wasm-bindgen-test = "0.3.43" diff --git a/crates/wasm-functions/src/lib.rs b/crates/wasm-functions/src/lib.rs new file mode 100644 index 0000000..8dd11ea --- /dev/null +++ b/crates/wasm-functions/src/lib.rs @@ -0,0 +1,84 @@ +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::error::ArrowError; +use std::sync::Arc; +use wasm_udfs::*; + +// ```bash +// cargo install wasm-bindgen-cli +// ``` + +// ```bash +// cargo test --target wasm32-unknown-unknown +// ``` + +// expose function f1 as external function +// add required bindgen, and required serialization/deserialization +export_udf_function!(f1); +// function should return error +export_udf_function!(f_return_error); +// function should panic +// export_udf_function!(f_panic); +// function should return arrow error +export_udf_function!(f_return_arrow_error); + +/// standard datafusion udf ... kind of +/// should return ArrayRef or ArrowError +fn f1(args: &[ArrayRef]) -> Result { + assert_eq!(2, args.len()); + + let base = args[0] + .as_any() + .downcast_ref::() + .expect("cast 0 failed"); + let exponent = args[1] + .as_any() + .downcast_ref::() + .expect("cast 1 failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| match (base, exponent) { + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + }) + .collect::(); + + // TODO: do we need arc here? + // only reason to stay to keep api same + // like datafusion udf's + Ok(Arc::new(array)) +} +/// function returns String Error +fn f_return_error(_args: &[ArrayRef]) -> Result { + Err("wasm function returned error".to_string()) +} + +/// function returns error +fn f_return_arrow_error(_args: &[ArrayRef]) -> Result { + Err(ArrowError::DivideByZero) +} + +// fn f_panic(_args: &[ArrayRef]) -> Result { +// panic!("wasm function panicked") +// } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, Float64Array}; + + use std::sync::Arc; + + #[wasm_bindgen_test::wasm_bindgen_test] + fn test_f1() { + let a: ArrayRef = Arc::new(Float64Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let args = vec![a, b]; + let result = f1(&args).unwrap(); + + assert_eq!(4, result.len()) + } +} diff --git a/crates/wasm-udfs/Cargo.toml b/crates/wasm-udfs/Cargo.toml new file mode 100644 index 0000000..867c110 --- /dev/null +++ b/crates/wasm-udfs/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "wasm-udfs" +version = "0.1.0" +edition = "2021" + +[dependencies] +arrow = "51.0.0" +paste = "1.0.15" +wasmedge-bindgen = "0.4.1" +wasmedge-bindgen-macro = "0.4.1" diff --git a/crates/wasm-udfs/src/lib.rs b/crates/wasm-udfs/src/lib.rs new file mode 100644 index 0000000..ed22397 --- /dev/null +++ b/crates/wasm-udfs/src/lib.rs @@ -0,0 +1,67 @@ +use arrow::{ + array::{Array, ArrayRef, RecordBatch}, + datatypes::{Field, Schema, SchemaRef}, +}; +pub use paste; +use std::sync::Arc; +pub use wasmedge_bindgen; +pub use wasmedge_bindgen_macro; + +/// packs slice of arrays to a batch +/// with schema generated from array types +pub fn pack_array(args: &[ArrayRef]) -> RecordBatch { + let fields = args + .iter() + .enumerate() + .map(|(i, f)| Field::new(format!("c{}", i), f.data_type().clone(), false)) + .collect::>(); + + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, args.to_vec()).unwrap() +} + +/// packs slice of arrays to a batch +/// with external schema +pub fn pack_array_with_schema(args: &[ArrayRef], schema: SchemaRef) -> RecordBatch { + RecordBatch::try_new(schema, args.to_vec()).unwrap() +} + +/// creates a arrow ipc blob +pub fn to_ipc(schema: &Schema, batch: RecordBatch) -> Vec { + let blob = vec![]; + let mut stream_writer = arrow::ipc::writer::StreamWriter::try_new(blob, schema).unwrap(); + stream_writer.write(&batch).unwrap(); + + stream_writer.into_inner().unwrap() +} + +/// creates arrow arrays from arrow ipc blob +pub fn from_ipc(payload: &[u8]) -> RecordBatch { + let mut batch = arrow::ipc::reader::StreamReader::try_new(payload, None).unwrap(); + batch.next().unwrap().unwrap() +} + +/// exports wasm function and performs all required +/// arrow ipc serialization/deserialization +/// +/// macro will create new function prefixed with `__wasm_udf_` +/// +// TODO: make this a proc macro maybe ? +#[macro_export] +macro_rules! export_udf_function { + ($name:ident) => { + paste::item! { + #[wasmedge_bindgen_macro::wasmedge_bindgen] + pub fn [<__wasm_udf_$name>](payload: Vec) -> Result,String> { + let args_batch = from_ipc(&payload); + let result = $name(args_batch.columns()); + // let batch = pack_array(&vec![result]); + // to_ipc(&batch.schema(), batch) + result.map(|result| pack_array(&vec![result])) + .map(|batch| to_ipc(&batch.schema(), batch)) + .map_err(|e| e.to_string()) + } + } + }; +} diff --git a/crates/wasmedge-factory/Cargo.toml b/crates/wasmedge-factory/Cargo.toml new file mode 100644 index 0000000..372e008 --- /dev/null +++ b/crates/wasmedge-factory/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "wasmedge-factory" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1.82" +datafusion = "38.0.0" +log = "0.4.22" +project-root = "0.2.2" +thiserror = "1.0.63" +tokio = "1.40.0" +wasm-udfs = { version = "0.1.0", path = "../wasm-udfs" } +wasmedge-sdk = "0.13.2" +weak-table = "0.3.2" diff --git a/crates/wasmedge-factory/src/lib.rs b/crates/wasmedge-factory/src/lib.rs new file mode 100644 index 0000000..e36e24c --- /dev/null +++ b/crates/wasmedge-factory/src/lib.rs @@ -0,0 +1,383 @@ +use std::{ + path::Path, + sync::{Arc, Weak}, +}; + +use datafusion::{ + arrow::datatypes::DataType, + common::exec_err, + error::{DataFusionError, Result}, + execution::context::{FunctionFactory, RegisterFunction, SessionState}, + logical_expr::{CreateFunction, DefinitionStatement, ScalarUDF}, +}; +use thiserror::Error; +use tokio::sync::Mutex; +use wasmedge_sdk::{config::ConfigBuilder, dock::VmDock, Module, VmBuilder}; +use weak_table::WeakValueHashMap; + +mod udf; + +type ModuleCache = Arc>>>; + +pub struct WasmFunctionFactory { + // note: + // https://github.com/WasmEdge/wasmedge-rust-sdk/issues/89 + // comments do not add up to VM interface, on top of it + // UDFs do not modify any state. leaving as it is for now + // may revert it later + modules: ModuleCache, +} + +#[async_trait::async_trait] +impl FunctionFactory for WasmFunctionFactory { + async fn create( + &self, + _state: &SessionState, + statement: CreateFunction, + ) -> Result { + let return_type = statement.return_type.expect("return type expected"); + let argument_types = statement + .args + .map(|args| { + args.into_iter() + .map(|a| a.data_type) + .collect::>() + }) + .unwrap_or_default(); + let declared_name = statement.name; + let (module_name, method_name) = match &statement.params.as_ { + Some(DefinitionStatement::SingleQuotedDef(path)) => Self::wasm_module_function(path)?, + None => return exec_err!("wasm function not defined "), + Some(f) => return exec_err!("wasm function incorrect {:?} ", f), + }; + + let vm = self.wasm_model_cache_or_load(&module_name).await?; + let f = crate::udf::WasmFunctionWrapper::new( + vm, + declared_name, + method_name, + argument_types, + return_type, + )?; + + Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f)))) + } +} + +impl Default for WasmFunctionFactory { + fn default() -> Self { + WasmFunctionFactory { + modules: Arc::new(Mutex::new(WeakValueHashMap::new())), + } + } +} + +impl WasmFunctionFactory { + /// returns cached module or + /// loads, caches module and returns module + /// for given module path + async fn wasm_model_cache_or_load( + &self, + wasm_module_path: &str, + ) -> std::result::Result, WasmFunctionError> { + // caching key is bit primitive, but good enough for now + let mut modules = self.modules.lock().await; + // lets assume creation of new module will not take too long + // and lock will be kept for a very short period of time, + // good enough for now + match modules.get(wasm_module_path) { + Some(module) => { + log::debug!("return cached VM for wasm_module={}", wasm_module_path); + Ok(module.clone()) + } + None => { + log::debug!("no cached VM for wasm_module={}", wasm_module_path); + let module = Self::wasm_model_load(wasm_module_path)?; + modules.insert(wasm_module_path.to_string(), module.clone()); + Ok(module) + } + } + } + + fn wasm_module_function(s: &str) -> Result<(String, String)> { + match s.split('!').collect::>()[..] { + [module, method] if !module.is_empty() && !method.is_empty() => { + Ok((module.to_string(), method.to_string())) + } + _ => exec_err!("bad module/method format"), + } + } + + fn wasm_model_load(wasm_module: &str) -> std::result::Result, WasmFunctionError> { + log::debug!("producing new VM for wasm_module={}", wasm_module); + let file = Path::new(&wasm_module); + let module = if file.is_absolute() { + Module::from_file(None, wasm_module)? + } else { + let mut project_root = project_root::get_project_root() + .map_err(|e| WasmFunctionError::Execution(e.to_string()))?; + project_root.push(file); + Module::from_file(None, &project_root)? + }; + + // default configuration will do for now + let config = ConfigBuilder::default().build()?; + + let vm = VmBuilder::new() + .with_config(config) + .build()? + .register_module(None, module)?; + + Ok(Arc::new(VmDock::new(vm))) + } + #[cfg(test)] + fn module_cache(&self) -> ModuleCache { + self.modules.clone() + } +} + +#[derive(Error, Debug)] +pub enum WasmFunctionError { + #[error("WasmEdge Error: {0}")] + WasmEdgeError(#[from] Box), + #[error("Execution Error: {0}")] + Execution(String), +} + +impl From for DataFusionError { + fn from(e: WasmFunctionError) -> Self { + // will do for now + DataFusionError::Execution(e.to_string()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datafusion::{ + arrow::array::{ArrayRef, Float64Array, RecordBatch}, + assert_batches_eq, + execution::context::SessionContext, + }; + + use crate::WasmFunctionFactory; + + #[test] + fn test_module_function_split() { + let (module, method) = WasmFunctionFactory::wasm_module_function("module!method").unwrap(); + assert_eq!("module", module); + assert_eq!("method", method); + + assert!(WasmFunctionFactory::wasm_module_function("!method").is_err()); + } + #[tokio::test] + async fn should_handle_happy_path() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let a: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.1])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + ctx.register_batch("t", batch)?; + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx + .sql("select a, b, f1(a,b) from t") + .await? + .collect() + .await?; + let expected = vec![ + "+-----+-----+-------------------+", + "| a | b | f1(t.a,t.b) |", + "+-----+-----+-------------------+", + "| 2.0 | 2.0 | 4.0 |", + "| 3.0 | 3.0 | 27.0 |", + "| 4.0 | 4.0 | 256.0 |", + "| 5.0 | 5.1 | 3670.684197150057 |", + "+-----+-----+-------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_handle_error() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f2(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation] wasm function returned error", + result.err().unwrap().to_string() + ); + + Ok(()) + } + + #[tokio::test] + async fn should_handle_arrow_error() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_arrow_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f2(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation] Divide by zero error", + result.err().unwrap().to_string() + ); + + Ok(()) + } + + #[tokio::test] + #[ignore = "WasmEdge does not handle panic after latest change"] + async fn should_handle_panic() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + // we register good function to verify that panich + // will not put vm to some unexpected state + ctx.sql(sql).await?.show().await?; + + let sql = r#" + CREATE FUNCTION f3(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_panic' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f3(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation Panic] unreachable", + result.err().unwrap().to_string() + ); + let result = ctx.sql("select f1(1.0,1.0)").await?.collect().await?; + let expected = vec![ + "+---------------------------+", + "| f1(Float64(1),Float64(1)) |", + "+---------------------------+", + "| 1.0 |", + "+---------------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn should_create_drop_function() -> datafusion::error::Result<()> { + let function_factory = Arc::new(WasmFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + + ctx.sql(sql).await?.show().await?; + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_arrow_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f1(2.0,2.0)").await?.collect().await?; + let expected = vec![ + "+---------------------------+", + "| f1(Float64(2),Float64(2)) |", + "+---------------------------+", + "| 4.0 |", + "+---------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + // we should have one modules caching + assert_eq!(1, function_factory.module_cache().lock().await.len()); + + let sql = r#" + DROP FUNCTION f1 + "#; + + ctx.sql(sql).await?.show().await?; + + let sql = r#" + DROP FUNCTION f2 + "#; + + ctx.sql(sql).await?.show().await?; + + // we should have none modules cached + // weak hashmap should drop VM after last function + // has been dropped. + // note, weak hash map is lazy to drop + assert_eq!( + 0, + function_factory + .module_cache() + .lock() + .await + .keys() + .collect::>() + .len() + ); + + Ok(()) + } +} + +// #[cfg(test)] +// #[ctor::ctor] +// fn init() { +// // Enable RUST_LOG logging configuration for test +// let _ = env_logger::builder().is_test(true).try_init(); +// } diff --git a/crates/wasmedge-factory/src/udf.rs b/crates/wasmedge-factory/src/udf.rs new file mode 100644 index 0000000..6b4bc89 --- /dev/null +++ b/crates/wasmedge-factory/src/udf.rs @@ -0,0 +1,122 @@ +use std::sync::Arc; + +use datafusion::{ + arrow::{ + array::ArrayRef, + datatypes::{DataType, Field, Schema, SchemaRef}, + }, + common::exec_err, + error::Result, + logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}, +}; +use wasm_udfs::{from_ipc, pack_array_with_schema, to_ipc}; +use wasmedge_sdk::dock::{Param, VmDock}; + +#[derive(Debug)] +pub(crate) struct WasmFunctionWrapper { + /// name which was used to in `CREATE FUNCTION` statement + declared_function_name: String, + /// wasm method to be called, can be found in `AS` part of the statement + // it would be much better if we could cache method handle + // but that is not currently supported by wasmedge sdk + wasm_method: String, + argument_schema: SchemaRef, + // TODO: function signature should be extracted from `CREATE FUNCTION` statement + signature: Signature, + return_type: DataType, + /// wasm VM which hosts module + vm: Arc, +} + +impl WasmFunctionWrapper { + pub(crate) fn new( + vm: Arc, + declared_function_name: String, + wasm_method: String, + argument_types: Vec, + return_type: DataType, + ) -> Result { + let fields = argument_types + .iter() + .enumerate() + .map(|(i, f)| Field::new(format!("c{}", i), f.clone(), false)) + .collect::>(); + + // we cache the schema + // as it will be used for every message + // passed between rust and wasm (not sure if we can avoid that) + let argument_schema = Arc::new(Schema::new(fields)); + + Ok(Self { + // prefix is not really needed but it looks cool :) + wasm_method: format!("__wasm_udf_{}", wasm_method), + declared_function_name, + signature: Signature::exact(argument_types, Volatility::Volatile), + return_type, + argument_schema, + vm, + }) + } +} + +impl ScalarUDFImpl for WasmFunctionWrapper { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.declared_function_name + } + + fn signature(&self) -> &datafusion::logical_expr::Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[datafusion::arrow::datatypes::DataType], + ) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke( + &self, + args: &[datafusion::logical_expr::ColumnarValue], + ) -> Result { + let arrays = ColumnarValue::values_to_arrays(args)?; + let batch = pack_array_with_schema(&arrays, self.argument_schema.clone()); + + let payload = to_ipc(&batch.schema(), batch); + let params = vec![Param::VecU8(&payload)]; + + let call_result = match self.vm.run_func(&self.wasm_method, params) { + Ok(result) => result, + // if wasm function panics it should get to this error + Err(e) => return exec_err!("[Wasm Invocation Panic] {}", e), + }; + + match call_result { + // function returned result + // in our case we expect only single result + // at position 0 + Ok(mut res) => { + // we should add errors to the protocol + let response = res.pop().unwrap().downcast::>().unwrap(); + let a = from_ipc(&response); + // aso we expect single column as the result + let result = a.column(0); + Ok(ColumnarValue::from(result.clone() as ArrayRef)) + } + // function returned error + Err(err) => { + exec_err!("[Wasm Invocation] {}", err) + } + } + } +} + +impl Drop for WasmFunctionWrapper { + fn drop(&mut self) { + log::debug!("drop wasm function, name={}", self.name()) + } +} From 41c34a5892b6a16d630c1567a0140749939986f1 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Fri, 6 Sep 2024 10:47:44 -0400 Subject: [PATCH 2/2] arrow-udf-wasm instead of wasmedge --- crates/wasmedge-factory/Cargo.toml | 3 +- crates/wasmedge-factory/src/lib.rs | 245 ++++++++++++++++++----------- crates/wasmedge-factory/src/udf.rs | 214 ++++++++++++------------- 3 files changed, 266 insertions(+), 196 deletions(-) diff --git a/crates/wasmedge-factory/Cargo.toml b/crates/wasmedge-factory/Cargo.toml index 372e008..f80dfed 100644 --- a/crates/wasmedge-factory/Cargo.toml +++ b/crates/wasmedge-factory/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +arrow-udf-wasm = "0.2.2" async-trait = "0.1.82" datafusion = "38.0.0" log = "0.4.22" @@ -11,5 +12,5 @@ project-root = "0.2.2" thiserror = "1.0.63" tokio = "1.40.0" wasm-udfs = { version = "0.1.0", path = "../wasm-udfs" } -wasmedge-sdk = "0.13.2" +# wasmedge-sdk = "0.13.2" weak-table = "0.3.2" diff --git a/crates/wasmedge-factory/src/lib.rs b/crates/wasmedge-factory/src/lib.rs index e36e24c..b7158bb 100644 --- a/crates/wasmedge-factory/src/lib.rs +++ b/crates/wasmedge-factory/src/lib.rs @@ -3,21 +3,83 @@ use std::{ sync::{Arc, Weak}, }; +use arrow_udf_wasm::Runtime; use datafusion::{ - arrow::datatypes::DataType, - common::exec_err, + arrow::{ + array::{ArrayRef, Float64Array}, + datatypes::DataType, + }, + common::{cast::as_float64_array, exec_err}, error::{DataFusionError, Result}, execution::context::{FunctionFactory, RegisterFunction, SessionState}, - logical_expr::{CreateFunction, DefinitionStatement, ScalarUDF}, + logical_expr::{ + create_udf, ColumnarValue, CreateFunction, DefinitionStatement, ScalarUDF, Volatility, + }, }; use thiserror::Error; use tokio::sync::Mutex; -use wasmedge_sdk::{config::ConfigBuilder, dock::VmDock, Module, VmBuilder}; +// use wasmedge_sdk::{config::ConfigBuilder, dock::VmDock, Module, VmBuilder}; use weak_table::WeakValueHashMap; mod udf; -type ModuleCache = Arc>>>; +// type ModuleCache = Arc>>>; + +fn test_udf() -> ScalarUDF { + // First, declare the actual implementation of the calculation + let pow = Arc::new(|args: &[ColumnarValue]| { + // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: + // 1. cast the values to the type we want + // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result + + // this is guaranteed by DataFusion based on the function's signature. + assert_eq!(args.len(), 2); + + // Expand the arguments to arrays (this is simple, but inefficient for + // single constant values). + let args = ColumnarValue::values_to_arrays(args)?; + + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! + let base = as_float64_array(&args[0]).expect("cast failed"); + let exponent = as_float64_array(&args[1]).expect("cast failed"); + + // The array lengths is guaranteed by DataFusion. We assert here to make it obvious. + assert_eq!(exponent.len(), base.len()); + + // 2. perform the computation + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + // in arrow, any value can be null. + // Here we decide to make our UDF to return null when either base or exponent is null. + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::(); + + // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) + // `Arc` because arrays are immutable, thread-safe, trait objects. + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + }); + + // Next: + // * give it a name so that it shows nicely when the plan is printed + // * declare what input it expects + // * declare its return type + let pow = create_udf( + "f1", + // expects two f64 + vec![DataType::Float64, DataType::Float64], + // returns f64 + Arc::new(DataType::Float64), + Volatility::Immutable, + pow, + ); + pow +} pub struct WasmFunctionFactory { // note: @@ -25,7 +87,7 @@ pub struct WasmFunctionFactory { // comments do not add up to VM interface, on top of it // UDFs do not modify any state. leaving as it is for now // may revert it later - modules: ModuleCache, + // modules: ModuleCache, } #[async_trait::async_trait] @@ -46,19 +108,26 @@ impl FunctionFactory for WasmFunctionFactory { .unwrap_or_default(); let declared_name = statement.name; let (module_name, method_name) = match &statement.params.as_ { - Some(DefinitionStatement::SingleQuotedDef(path)) => Self::wasm_module_function(path)?, + Some(DefinitionStatement::SingleQuotedDef(path)) => { + println!("Got create function path: {}", path); + Self::wasm_module_function(path)? + } None => return exec_err!("wasm function not defined "), Some(f) => return exec_err!("wasm function incorrect {:?} ", f), }; - let vm = self.wasm_model_cache_or_load(&module_name).await?; - let f = crate::udf::WasmFunctionWrapper::new( - vm, - declared_name, - method_name, - argument_types, - return_type, - )?; + // let rt = Runtime::new(binary); + + // let vm = self.wasm_model_cache_or_load(&module_name).await?; + // let f = crate::udf::WasmFunctionWrapper::new( + // vm, + // declared_name, + // method_name, + // argument_types, + // return_type, + // )?; + + let f = test_udf(); Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f)))) } @@ -67,7 +136,7 @@ impl FunctionFactory for WasmFunctionFactory { impl Default for WasmFunctionFactory { fn default() -> Self { WasmFunctionFactory { - modules: Arc::new(Mutex::new(WeakValueHashMap::new())), + // modules: Arc::new(Mutex::new(WeakValueHashMap::new())), } } } @@ -76,28 +145,28 @@ impl WasmFunctionFactory { /// returns cached module or /// loads, caches module and returns module /// for given module path - async fn wasm_model_cache_or_load( - &self, - wasm_module_path: &str, - ) -> std::result::Result, WasmFunctionError> { - // caching key is bit primitive, but good enough for now - let mut modules = self.modules.lock().await; - // lets assume creation of new module will not take too long - // and lock will be kept for a very short period of time, - // good enough for now - match modules.get(wasm_module_path) { - Some(module) => { - log::debug!("return cached VM for wasm_module={}", wasm_module_path); - Ok(module.clone()) - } - None => { - log::debug!("no cached VM for wasm_module={}", wasm_module_path); - let module = Self::wasm_model_load(wasm_module_path)?; - modules.insert(wasm_module_path.to_string(), module.clone()); - Ok(module) - } - } - } + // async fn wasm_model_cache_or_load( + // &self, + // wasm_module_path: &str, + // ) -> std::result::Result, WasmFunctionError> { + // // caching key is bit primitive, but good enough for now + // let mut modules = self.modules.lock().await; + // // lets assume creation of new module will not take too long + // // and lock will be kept for a very short period of time, + // // good enough for now + // match modules.get(wasm_module_path) { + // Some(module) => { + // log::debug!("return cached VM for wasm_module={}", wasm_module_path); + // Ok(module.clone()) + // } + // None => { + // log::debug!("no cached VM for wasm_module={}", wasm_module_path); + // let module = Self::wasm_model_load(wasm_module_path)?; + // modules.insert(wasm_module_path.to_string(), module.clone()); + // Ok(module) + // } + // } + // } fn wasm_module_function(s: &str) -> Result<(String, String)> { match s.split('!').collect::>()[..] { @@ -108,48 +177,48 @@ impl WasmFunctionFactory { } } - fn wasm_model_load(wasm_module: &str) -> std::result::Result, WasmFunctionError> { - log::debug!("producing new VM for wasm_module={}", wasm_module); - let file = Path::new(&wasm_module); - let module = if file.is_absolute() { - Module::from_file(None, wasm_module)? - } else { - let mut project_root = project_root::get_project_root() - .map_err(|e| WasmFunctionError::Execution(e.to_string()))?; - project_root.push(file); - Module::from_file(None, &project_root)? - }; - - // default configuration will do for now - let config = ConfigBuilder::default().build()?; - - let vm = VmBuilder::new() - .with_config(config) - .build()? - .register_module(None, module)?; - - Ok(Arc::new(VmDock::new(vm))) - } - #[cfg(test)] - fn module_cache(&self) -> ModuleCache { - self.modules.clone() - } + // fn wasm_model_load(wasm_module: &str) -> std::result::Result, WasmFunctionError> { + // log::debug!("producing new VM for wasm_module={}", wasm_module); + // let file = Path::new(&wasm_module); + // let module = if file.is_absolute() { + // Module::from_file(None, wasm_module)? + // } else { + // let mut project_root = project_root::get_project_root() + // .map_err(|e| WasmFunctionError::Execution(e.to_string()))?; + // project_root.push(file); + // Module::from_file(None, &project_root)? + // }; + // + // // default configuration will do for now + // let config = ConfigBuilder::default().build()?; + // + // let vm = VmBuilder::new() + // .with_config(config) + // .build()? + // .register_module(None, module)?; + // + // Ok(Arc::new(VmDock::new(vm))) + // } + // #[cfg(test)] + // fn module_cache(&self) -> ModuleCache { + // self.modules.clone() + // } } -#[derive(Error, Debug)] -pub enum WasmFunctionError { - #[error("WasmEdge Error: {0}")] - WasmEdgeError(#[from] Box), - #[error("Execution Error: {0}")] - Execution(String), -} +// #[derive(Error, Debug)] +// pub enum WasmFunctionError { +// #[error("WasmEdge Error: {0}")] +// WasmEdgeError(#[from] Box), +// #[error("Execution Error: {0}")] +// Execution(String), +// } -impl From for DataFusionError { - fn from(e: WasmFunctionError) -> Self { - // will do for now - DataFusionError::Execution(e.to_string()) - } -} +// impl From for DataFusionError { +// fn from(e: WasmFunctionError) -> Self { +// // will do for now +// DataFusionError::Execution(e.to_string()) +// } +// } #[cfg(test)] mod test { @@ -342,7 +411,7 @@ mod test { assert_batches_eq!(expected, &result); // we should have one modules caching - assert_eq!(1, function_factory.module_cache().lock().await.len()); + // assert_eq!(1, function_factory.module_cache().lock().await.len()); let sql = r#" DROP FUNCTION f1 @@ -360,16 +429,16 @@ mod test { // weak hashmap should drop VM after last function // has been dropped. // note, weak hash map is lazy to drop - assert_eq!( - 0, - function_factory - .module_cache() - .lock() - .await - .keys() - .collect::>() - .len() - ); + // assert_eq!( + // 0, + // function_factory + // .module_cache() + // .lock() + // .await + // .keys() + // .collect::>() + // .len() + // ); Ok(()) } diff --git a/crates/wasmedge-factory/src/udf.rs b/crates/wasmedge-factory/src/udf.rs index 6b4bc89..3acf023 100644 --- a/crates/wasmedge-factory/src/udf.rs +++ b/crates/wasmedge-factory/src/udf.rs @@ -10,113 +10,113 @@ use datafusion::{ logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}, }; use wasm_udfs::{from_ipc, pack_array_with_schema, to_ipc}; -use wasmedge_sdk::dock::{Param, VmDock}; +// use wasmedge_sdk::dock::{Param, VmDock}; -#[derive(Debug)] -pub(crate) struct WasmFunctionWrapper { - /// name which was used to in `CREATE FUNCTION` statement - declared_function_name: String, - /// wasm method to be called, can be found in `AS` part of the statement - // it would be much better if we could cache method handle - // but that is not currently supported by wasmedge sdk - wasm_method: String, - argument_schema: SchemaRef, - // TODO: function signature should be extracted from `CREATE FUNCTION` statement - signature: Signature, - return_type: DataType, - /// wasm VM which hosts module - vm: Arc, -} +// #[derive(Debug)] +// pub(crate) struct WasmFunctionWrapper { +// /// name which was used to in `CREATE FUNCTION` statement +// declared_function_name: String, +// /// wasm method to be called, can be found in `AS` part of the statement +// // it would be much better if we could cache method handle +// // but that is not currently supported by wasmedge sdk +// wasm_method: String, +// argument_schema: SchemaRef, +// // TODO: function signature should be extracted from `CREATE FUNCTION` statement +// signature: Signature, +// return_type: DataType, +// /// wasm VM which hosts module +// vm: Arc, +// } -impl WasmFunctionWrapper { - pub(crate) fn new( - vm: Arc, - declared_function_name: String, - wasm_method: String, - argument_types: Vec, - return_type: DataType, - ) -> Result { - let fields = argument_types - .iter() - .enumerate() - .map(|(i, f)| Field::new(format!("c{}", i), f.clone(), false)) - .collect::>(); +// impl WasmFunctionWrapper { +// pub(crate) fn new( +// vm: Arc, +// declared_function_name: String, +// wasm_method: String, +// argument_types: Vec, +// return_type: DataType, +// ) -> Result { +// let fields = argument_types +// .iter() +// .enumerate() +// .map(|(i, f)| Field::new(format!("c{}", i), f.clone(), false)) +// .collect::>(); +// +// // we cache the schema +// // as it will be used for every message +// // passed between rust and wasm (not sure if we can avoid that) +// let argument_schema = Arc::new(Schema::new(fields)); +// +// Ok(Self { +// // prefix is not really needed but it looks cool :) +// wasm_method: format!("__wasm_udf_{}", wasm_method), +// declared_function_name, +// signature: Signature::exact(argument_types, Volatility::Volatile), +// return_type, +// argument_schema, +// vm, +// }) +// } +// } - // we cache the schema - // as it will be used for every message - // passed between rust and wasm (not sure if we can avoid that) - let argument_schema = Arc::new(Schema::new(fields)); - - Ok(Self { - // prefix is not really needed but it looks cool :) - wasm_method: format!("__wasm_udf_{}", wasm_method), - declared_function_name, - signature: Signature::exact(argument_types, Volatility::Volatile), - return_type, - argument_schema, - vm, - }) - } -} - -impl ScalarUDFImpl for WasmFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - &self.declared_function_name - } - - fn signature(&self) -> &datafusion::logical_expr::Signature { - &self.signature - } - - fn return_type( - &self, - _arg_types: &[datafusion::arrow::datatypes::DataType], - ) -> Result { - Ok(self.return_type.clone()) - } - - fn invoke( - &self, - args: &[datafusion::logical_expr::ColumnarValue], - ) -> Result { - let arrays = ColumnarValue::values_to_arrays(args)?; - let batch = pack_array_with_schema(&arrays, self.argument_schema.clone()); - - let payload = to_ipc(&batch.schema(), batch); - let params = vec![Param::VecU8(&payload)]; - - let call_result = match self.vm.run_func(&self.wasm_method, params) { - Ok(result) => result, - // if wasm function panics it should get to this error - Err(e) => return exec_err!("[Wasm Invocation Panic] {}", e), - }; - - match call_result { - // function returned result - // in our case we expect only single result - // at position 0 - Ok(mut res) => { - // we should add errors to the protocol - let response = res.pop().unwrap().downcast::>().unwrap(); - let a = from_ipc(&response); - // aso we expect single column as the result - let result = a.column(0); - Ok(ColumnarValue::from(result.clone() as ArrayRef)) - } - // function returned error - Err(err) => { - exec_err!("[Wasm Invocation] {}", err) - } - } - } -} - -impl Drop for WasmFunctionWrapper { - fn drop(&mut self) { - log::debug!("drop wasm function, name={}", self.name()) - } -} +// impl ScalarUDFImpl for WasmFunctionWrapper { +// fn as_any(&self) -> &dyn std::any::Any { +// self +// } +// +// fn name(&self) -> &str { +// &self.declared_function_name +// } +// +// fn signature(&self) -> &datafusion::logical_expr::Signature { +// &self.signature +// } +// +// fn return_type( +// &self, +// _arg_types: &[datafusion::arrow::datatypes::DataType], +// ) -> Result { +// Ok(self.return_type.clone()) +// } +// +// fn invoke( +// &self, +// args: &[datafusion::logical_expr::ColumnarValue], +// ) -> Result { +// let arrays = ColumnarValue::values_to_arrays(args)?; +// let batch = pack_array_with_schema(&arrays, self.argument_schema.clone()); +// +// let payload = to_ipc(&batch.schema(), batch); +// let params = vec![Param::VecU8(&payload)]; +// +// let call_result = match self.vm.run_func(&self.wasm_method, params) { +// Ok(result) => result, +// // if wasm function panics it should get to this error +// Err(e) => return exec_err!("[Wasm Invocation Panic] {}", e), +// }; +// +// match call_result { +// // function returned result +// // in our case we expect only single result +// // at position 0 +// Ok(mut res) => { +// // we should add errors to the protocol +// let response = res.pop().unwrap().downcast::>().unwrap(); +// let a = from_ipc(&response); +// // aso we expect single column as the result +// let result = a.column(0); +// Ok(ColumnarValue::from(result.clone() as ArrayRef)) +// } +// // function returned error +// Err(err) => { +// exec_err!("[Wasm Invocation] {}", err) +// } +// } +// } +// } +// +// impl Drop for WasmFunctionWrapper { +// fn drop(&mut self) { +// log::debug!("drop wasm function, name={}", self.name()) +// } +// }