diff --git a/Cargo.toml b/Cargo.toml index 7085acc..8d009fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,8 @@ members = [ ] [workspace.dependencies] -polars = { version = "0.34", default-features = false } -polars-core = { version = "0.34", default-features = false } -polars-ffi = { version = "0.34", default-features = false } -polars-plan = { version = "0.34", default-feautres = false } -polars-lazy = { version = "0.34", default-features = false } +polars = { version = "0.35", default-features = false } +polars-core = { version = "0.35", default-features = false } +polars-ffi = { version = "0.35", default-features = false } +polars-plan = { version = "0.35", default-features = false } +polars-lazy = { version = "0.35", default-features = false } diff --git a/example/derive_expression/expression_lib/Cargo.toml b/example/derive_expression/expression_lib/Cargo.toml index 9dccc21..5958b62 100644 --- a/example/derive_expression/expression_lib/Cargo.toml +++ b/example/derive_expression/expression_lib/Cargo.toml @@ -12,6 +12,7 @@ crate-type = ["cdylib"] polars = { workspace = true, features = ["fmt", "dtype-date"], default-features = false } pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] } serde = { version = "1", features = ["derive"] } +rayon = "1.7.0" [target.'cfg(target_os = "linux")'.dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/example/derive_expression/expression_lib/expression_lib/__init__.py b/example/derive_expression/expression_lib/expression_lib/__init__.py index 58b13eb..e37e881 100644 --- a/example/derive_expression/expression_lib/expression_lib/__init__.py +++ b/example/derive_expression/expression_lib/expression_lib/__init__.py @@ -11,7 +11,7 @@ def __init__(self, expr: pl.Expr): self._expr = expr def pig_latinnify(self, capitalize: bool = False) -> pl.Expr: - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, 
symbol="pig_latinnify", is_elementwise=True, @@ -28,7 +28,7 @@ def append_args( """ This example shows how arguments other than `Series` can be used. """ - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, args=[], kwargs={ @@ -48,7 +48,7 @@ def __init__(self, expr: pl.Expr): self._expr = expr def hamming_distance(self, other: IntoExpr) -> pl.Expr: - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, args=[other], symbol="hamming_distance", @@ -56,7 +56,7 @@ def hamming_distance(self, other: IntoExpr) -> pl.Expr: ) def jaccard_similarity(self, other: IntoExpr) -> pl.Expr: - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, args=[other], symbol="jaccard_similarity", @@ -70,7 +70,7 @@ def haversine( end_lat: IntoExpr, end_long: IntoExpr, ) -> pl.Expr: - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, args=[start_lat, start_long, end_lat, end_long], symbol="haversine", @@ -85,8 +85,19 @@ def __init__(self, expr: pl.Expr): self._expr = expr def is_leap_year(self) -> pl.Expr: - return self._expr._register_plugin( + return self._expr.register_plugin( lib=lib, symbol="is_leap_year", is_elementwise=True, ) + +@pl.api.register_expr_namespace("panic") +class Panic: + def __init__(self, expr: pl.Expr): + self._expr = expr + + def panic(self) -> pl.Expr: + return self._expr.register_plugin( + lib=lib, + symbol="panic", + ) diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs index ca764f2..8a34c1c 100644 --- a/example/derive_expression/expression_lib/src/expressions.rs +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -1,6 +1,7 @@ use polars::prelude::*; use polars_plan::dsl::FieldsMapper; -use pyo3_polars::derive::polars_expr; +use pyo3_polars::derive::{polars_expr, CallerContext}; +use pyo3_polars::export::polars_core::POOL; use serde::Deserialize; 
use std::fmt::Write; @@ -31,6 +32,61 @@ fn pig_latinnify(inputs: &[Series], kwargs: PigLatinKwargs) -> PolarsResult Vec<(usize, usize)> { + if n == 1 { + vec![(0, len)] + } else { + let chunk_size = len / n; + + (0..n) + .map(|partition| { + let offset = partition * chunk_size; + let len = if partition == (n - 1) { + len - offset + } else { + chunk_size + }; + (partition * chunk_size, len) + }) + .collect() + } +} + +/// This expression will run in parallel if the `context` allows it. +#[polars_expr(output_type=Utf8)] +fn pig_latinnify_with_paralellism( + inputs: &[Series], + context: CallerContext, + kwargs: PigLatinKwargs, +) -> PolarsResult { + use rayon::prelude::*; + let ca = inputs[0].utf8()?; + + if context.parallel() { + let out: Utf8Chunked = + ca.apply_to_buffer(|value, output| pig_latin_str(value, kwargs.capitalize, output)); + Ok(out.into_series()) + } else { + POOL.install(|| { + let n_threads = POOL.current_num_threads(); + let splits = split_offsets(ca.len(), n_threads); + + let chunks: Vec<_> = splits + .into_par_iter() + .map(|(offset, len)| { + let sliced = ca.slice(offset as i64, len); + let out = sliced.apply_to_buffer(|value, output| { + pig_latin_str(value, kwargs.capitalize, output) + }); + out.downcast_iter().cloned().collect::>() + }) + .collect(); + + Ok(Utf8Chunked::from_chunk_iter(ca.name(), chunks.into_iter().flatten()).into_series()) + }) + } +} + #[polars_expr(output_type=Float64)] fn jaccard_similarity(inputs: &[Series]) -> PolarsResult { let a = inputs[0].list()?; @@ -119,3 +175,8 @@ fn is_leap_year(input: &[Series]) -> PolarsResult { Ok(out.into_series()) } + +#[polars_expr(output_type=Boolean)] +fn panic(_input: &[Series]) -> PolarsResult { + todo!() +} diff --git a/example/derive_expression/run.py b/example/derive_expression/run.py index 78a0cca..58cf5c4 100644 --- a/example/derive_expression/run.py +++ b/example/derive_expression/run.py @@ -1,5 +1,5 @@ import polars as pl -from expression_lib import Language, Distance +from 
expression_lib import * from datetime import date df = pl.DataFrame( @@ -45,3 +45,9 @@ ) except pl.ComputeError as e: assert "the plugin failed with message" in str(e) + + +# For now test if we abort on panic. +out.with_columns( + pl.col("names").panic.panic() +) \ No newline at end of file diff --git a/pyo3-polars-derive/src/lib.rs b/pyo3-polars-derive/src/lib.rs index 25ca3eb..978930d 100644 --- a/pyo3-polars-derive/src/lib.rs +++ b/pyo3-polars-derive/src/lib.rs @@ -3,7 +3,6 @@ mod keywords; use proc_macro::TokenStream; use quote::quote; -use std::panic::UnwindSafe; use std::sync::atomic::{AtomicBool, Ordering}; use syn::{parse_macro_input, FnArg}; @@ -15,7 +14,7 @@ fn insert_error_function() -> proc_macro2::TokenStream { // Only expose the error retrieval function on the first expression. if !is_init { quote!( - pub use pyo3_polars::derive::get_last_error_message; + pub use pyo3_polars::derive::_polars_plugin_get_last_error_message; ) } else { proc_macro2::TokenStream::new() @@ -24,7 +23,6 @@ fn insert_error_function() -> proc_macro2::TokenStream { fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { quote!( - let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len); let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) { @@ -44,6 +42,40 @@ fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::To ) } +fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { + quote!( + let context = *context; + + // define the function + #ast + + // call the function + let result: PolarsResult = #fn_name(&inputs, context); + ) +} + +fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { + quote!( + let context = *context; + + let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len); + + let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) { + Ok(value) => value, + Err(err) => { + 
pyo3_polars::derive::_update_last_error(err); + return; + } + }; + + // define the function + #ast + + // call the function + let result: PolarsResult = #fn_name(&inputs, context, kwargs); + ) +} + fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream { quote!( // define the function @@ -57,7 +89,7 @@ fn quote_process_results() -> proc_macro2::TokenStream { quote!(match result { Ok(out) => { // Update return value. - *return_value = polars_ffi::export_series(&out); + *return_value = polars_ffi::version_0::export_series(&out); } Err(err) => { // Set latest error, but leave return value in empty state. @@ -68,32 +100,45 @@ fn quote_process_results() -> proc_macro2::TokenStream { fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { // count how often the user define a kwargs argument. - let n_kwargs = ast + let args = ast .sig .inputs .iter() - .filter(|fn_arg| { + .skip(1) + .map(|fn_arg| { if let FnArg::Typed(pat) = fn_arg { if let syn::Pat::Ident(pat) = pat.pat.as_ref() { - pat.ident.to_string() == "kwargs" + pat.ident.to_string() } else { - false + panic!("expected an argument") } } else { - true + panic!("expected a type argument") } }) - .count(); + .collect::>(); let fn_name = &ast.sig.ident; let error_msg_fn = insert_error_function(); - let quote_call = match n_kwargs { + // Get the tokenstream of the call logic. 
+ let quote_call = match args.len() { 0 => quote_call_no_kwargs(&ast, fn_name), - 1 => quote_call_kwargs(&ast, fn_name), - _ => unreachable!(), // arguments are unique + 1 => match args[0].as_str() { + "kwargs" => quote_call_kwargs(&ast, fn_name), + "context" => quote_call_context(&ast, fn_name), + a => panic!("didn't expect argument {}", a), + }, + 2 => match (args[0].as_str(), args[1].as_str()) { + ("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name), + ("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"), + (a, b) => panic!("didn't expect arguments {}, {}", a, b), + }, + _ => panic!("didn't expect so many arguments"), }; + let quote_process_result = quote_process_results(); + let fn_name = get_expression_function_name(fn_name); quote!( use pyo3_polars::export::*; @@ -103,14 +148,15 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { // create the outer public function #[no_mangle] pub unsafe extern "C" fn #fn_name ( - e: *mut polars_ffi::SeriesExport, + e: *mut polars_ffi::version_0::SeriesExport, input_len: usize, kwargs_ptr: *const u8, kwargs_len: usize, - return_value: *mut polars_ffi::SeriesExport + return_value: *mut polars_ffi::version_0::SeriesExport, + context: *mut polars_ffi::version_0::CallerContext ) { let panic_result = std::panic::catch_unwind(move || { - let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap(); + let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap(); #quote_call @@ -119,8 +165,7 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { }); if panic_result.is_err() { - // Set latest to panic and nullify return value; - *return_value = polars_ffi::SeriesExport::empty(); + // Set latest to panic; pyo3_polars::derive::_set_panic(); } @@ -128,8 +173,21 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { ) } -fn get_field_name(fn_name: &syn::Ident) -> syn::Ident { - 
syn::Ident::new(&format!("__polars_field_{}", fn_name), fn_name.span()) +fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident { + syn::Ident::new( + &format!( + "_polars_plugin_field_{}", + fn_name, + ), + fn_name.span(), + ) +} + +fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident { + syn::Ident::new( + &format!("_polars_plugin_{}", fn_name), + fn_name.span(), + ) } fn get_inputs() -> proc_macro2::TokenStream { @@ -147,7 +205,7 @@ fn create_field_function( fn_name: &syn::Ident, dtype_fn_name: &syn::Ident, ) -> proc_macro2::TokenStream { - let map_field_name = get_field_name(fn_name); + let map_field_name = get_field_function_name(fn_name); let inputs = get_inputs(); quote! ( @@ -175,8 +233,7 @@ fn create_field_function( }); if panic_result.is_err() { - // Set latest to panic and nullify return value; - *return_value = polars_core::export::arrow::ffi::ArrowSchema::empty(); + // Set latest to panic; pyo3_polars::derive::_set_panic(); } } @@ -187,7 +244,7 @@ fn create_field_function_from_with_dtype( fn_name: &syn::Ident, dtype: syn::Ident, ) -> proc_macro2::TokenStream { - let map_field_name = get_field_name(fn_name); + let map_field_name = get_field_function_name(fn_name); let inputs = get_inputs(); quote! ( diff --git a/pyo3-polars/src/derive.rs b/pyo3-polars/src/derive.rs index 8c9f4e0..8616a07 100644 --- a/pyo3-polars/src/derive.rs +++ b/pyo3-polars/src/derive.rs @@ -5,6 +5,10 @@ use serde::Deserialize; use std::cell::RefCell; use std::ffi::CString; +/// Gives the caller extra information on how to execute the expression. +pub use polars_ffi::version_0::CallerContext; + +/// A default opaque kwargs type. pub type DefaultKwargs = serde_pickle::Value; thread_local! 
{ @@ -31,6 +35,13 @@ pub fn _set_panic() { } #[no_mangle] -pub unsafe extern "C" fn get_last_error_message() -> *const std::os::raw::c_char { +pub unsafe extern "C" fn _polars_plugin_get_last_error_message() -> *const std::os::raw::c_char { LAST_ERROR.with(|prev| prev.borrow_mut().as_ptr()) } + +#[no_mangle] +pub unsafe extern "C" fn _polars_plugin_get_version() -> u32 { + let (major, minor) = polars_ffi::get_version(); + // Stack bits together + ((major as u32) << 16) + minor as u32 +}