Skip to content

Commit

Permalink
add version and context (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Nov 17, 2023
1 parent b3b352b commit 04df158
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 38 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ members = [
]

[workspace.dependencies]
polars = { version = "0.34", default-features = false }
polars-core = { version = "0.34", default-features = false }
polars-ffi = { version = "0.34", default-features = false }
polars-plan = { version = "0.34", default-feautres = false }
polars-lazy = { version = "0.34", default-features = false }
polars = { version = "0.35", default-features = false }
polars-core = { version = "0.35", default-features = false }
polars-ffi = { version = "0.35", default-features = false }
polars-plan = { version = "0.35", default-features = false }
polars-lazy = { version = "0.35", default-features = false }
1 change: 1 addition & 0 deletions example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ crate-type = ["cdylib"]
polars = { workspace = true, features = ["fmt", "dtype-date"], default-features = false }
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
rayon = "1.7.0"

[target.'cfg(target_os = "linux")'.dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, expr: pl.Expr):
self._expr = expr

def pig_latinnify(self, capitalize: bool = False) -> pl.Expr:
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
symbol="pig_latinnify",
is_elementwise=True,
Expand All @@ -28,7 +28,7 @@ def append_args(
"""
This example shows how arguments other than `Series` can be used.
"""
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
args=[],
kwargs={
Expand All @@ -48,15 +48,15 @@ def __init__(self, expr: pl.Expr):
self._expr = expr

def hamming_distance(self, other: IntoExpr) -> pl.Expr:
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="hamming_distance",
is_elementwise=True,
)

def jaccard_similarity(self, other: IntoExpr) -> pl.Expr:
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
args=[other],
symbol="jaccard_similarity",
Expand All @@ -70,7 +70,7 @@ def haversine(
end_lat: IntoExpr,
end_long: IntoExpr,
) -> pl.Expr:
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
args=[start_lat, start_long, end_lat, end_long],
symbol="haversine",
Expand All @@ -85,8 +85,19 @@ def __init__(self, expr: pl.Expr):
self._expr = expr

def is_leap_year(self) -> pl.Expr:
return self._expr._register_plugin(
return self._expr.register_plugin(
lib=lib,
symbol="is_leap_year",
is_elementwise=True,
)

@pl.api.register_expr_namespace("panic")
class Panic:
    """Expression namespace ``panic`` exposing the plugin symbol of the same name."""

    def __init__(self, expr: pl.Expr):
        # Keep the wrapped expression; the plugin is registered on it.
        self._expr = expr

    def panic(self) -> pl.Expr:
        """Register the ``panic`` plugin symbol from ``lib`` on the wrapped expression."""
        plugin_options = {
            "lib": lib,
            "symbol": "panic",
        }
        return self._expr.register_plugin(**plugin_options)
63 changes: 62 additions & 1 deletion example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::polars_expr;
use pyo3_polars::derive::{polars_expr, CallerContext};
use pyo3_polars::export::polars_core::POOL;
use serde::Deserialize;
use std::fmt::Write;

Expand Down Expand Up @@ -31,6 +32,61 @@ fn pig_latinnify(inputs: &[Series], kwargs: PigLatinKwargs) -> PolarsResult<Seri
Ok(out.into_series())
}

/// Splits the range `0..len` into contiguous `(offset, length)` partitions.
///
/// For `n >= 2` exactly `n` partitions are returned: every partition except
/// the last has length `len / n`, and the last one absorbs the remainder, so
/// the partitions always cover `0..len` without gaps or overlap. When
/// `len < n` the leading partitions are empty.
///
/// For `n <= 1` the single partition `(0, len)` is returned. This includes
/// `n == 0`, which would otherwise divide by zero.
fn split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> {
    if n <= 1 {
        return vec![(0, len)];
    }

    let chunk_size = len / n;
    let last = n - 1;

    (0..n)
        .map(|partition| {
            let offset = partition * chunk_size;
            // The last partition takes whatever is left after the even split.
            let partition_len = if partition == last {
                len - offset
            } else {
                chunk_size
            };
            (offset, partition_len)
        })
        .collect()
}

/// Pig-latin conversion that parallelizes itself only when the caller does not.
///
/// When `context.parallel()` is true the whole column is converted in one
/// sequential pass. Otherwise the column is cut into one slice per `POOL`
/// thread, the slices are converted with rayon inside `POOL`, and the
/// resulting chunks are stitched back together under the input's name.
#[polars_expr(output_type=Utf8)]
fn pig_latinnify_with_paralellism(
    inputs: &[Series],
    context: CallerContext,
    kwargs: PigLatinKwargs,
) -> PolarsResult<Series> {
    use rayon::prelude::*;

    let ca = inputs[0].utf8()?;

    // Sequential path: convert everything in a single pass.
    if context.parallel() {
        let converted: Utf8Chunked =
            ca.apply_to_buffer(|value, output| pig_latin_str(value, kwargs.capitalize, output));
        return Ok(converted.into_series());
    }

    // Parallel path: one slice per pool thread.
    POOL.install(|| {
        let partitions = split_offsets(ca.len(), POOL.current_num_threads());

        let arrays: Vec<_> = partitions
            .into_par_iter()
            .map(|(offset, len)| {
                let part = ca.slice(offset as i64, len);
                let converted = part.apply_to_buffer(|value, output| {
                    pig_latin_str(value, kwargs.capitalize, output)
                });
                converted.downcast_iter().cloned().collect::<Vec<_>>()
            })
            .collect();

        let merged = Utf8Chunked::from_chunk_iter(ca.name(), arrays.into_iter().flatten());
        Ok(merged.into_series())
    })
}

#[polars_expr(output_type=Float64)]
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
let a = inputs[0].list()?;
Expand Down Expand Up @@ -119,3 +175,8 @@ fn is_leap_year(input: &[Series]) -> PolarsResult<Series> {

Ok(out.into_series())
}

/// Example expression that never returns a value: its body is `todo!()`, so
/// every call panics. The example `run.py` script invokes it at the end to
/// check what happens when a plugin panics.
#[polars_expr(output_type=Boolean)]
fn panic(_input: &[Series]) -> PolarsResult<Series> {
todo!()
}
8 changes: 7 additions & 1 deletion example/derive_expression/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import polars as pl
from expression_lib import Language, Distance
from expression_lib import *
from datetime import date

df = pl.DataFrame(
Expand Down Expand Up @@ -45,3 +45,9 @@
)
except pl.ComputeError as e:
assert "the plugin failed with message" in str(e)


# For now test if we abort on panic.
out.with_columns(
pl.col("names").panic.panic()
)
105 changes: 81 additions & 24 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod keywords;

use proc_macro::TokenStream;
use quote::quote;
use std::panic::UnwindSafe;
use std::sync::atomic::{AtomicBool, Ordering};
use syn::{parse_macro_input, FnArg};

Expand All @@ -15,7 +14,7 @@ fn insert_error_function() -> proc_macro2::TokenStream {
// Only expose the error retrieval function on the first expression.
if !is_init {
quote!(
pub use pyo3_polars::derive::get_last_error_message;
pub use pyo3_polars::derive::_polars_plugin_get_last_error_message;
)
} else {
proc_macro2::TokenStream::new()
Expand All @@ -24,7 +23,6 @@ fn insert_error_function() -> proc_macro2::TokenStream {

fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
quote!(

let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
Expand All @@ -44,6 +42,40 @@ fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::To
)
}

/// Builds the call fragment for a user function whose only extra argument is `context`.
///
/// The returned tokens are meant to be spliced into the generated `extern "C"`
/// wrapper (see `create_expression_function`), where `inputs` and the raw
/// `context` pointer are already in scope. The fragment:
/// 1. copies the caller context out of the `context` pointer,
/// 2. re-emits the user's function item (`ast`) so it can be called locally,
/// 3. calls it as `fn_name(&inputs, context)` and binds the outcome to `result`,
///    which the tokens from `quote_process_results` match on.
fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
quote!(
let context = *context;

// define the function
#ast

// call the function
let result: PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context);
)
}

/// Builds the call fragment for a user function taking `context` and `kwargs`, in that order.
///
/// The returned tokens are meant to be spliced into the generated `extern "C"`
/// wrapper (see `create_expression_function`), where `inputs`, `context`,
/// `kwargs_ptr` and `kwargs_len` are already in scope. The fragment:
/// 1. copies the caller context out of the `context` pointer,
/// 2. views the raw kwargs bytes as a slice and parses them with
///    `pyo3_polars::derive::_parse_kwargs`; on a parse error the error is
///    recorded via `_update_last_error` and the generated code returns early
///    without calling the user function,
/// 3. re-emits the user's function item (`ast`) so it can be called locally,
/// 4. calls it as `fn_name(&inputs, context, kwargs)` and binds the outcome to
///    `result`, which the tokens from `quote_process_results` match on.
fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
quote!(
let context = *context;

let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);

let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
Ok(value) => value,
Err(err) => {
pyo3_polars::derive::_update_last_error(err);
return;
}
};

// define the function
#ast

// call the function
let result: PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context, kwargs);
)
}

fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
quote!(
// define the function
Expand All @@ -57,7 +89,7 @@ fn quote_process_results() -> proc_macro2::TokenStream {
quote!(match result {
Ok(out) => {
// Update return value.
*return_value = polars_ffi::export_series(&out);
*return_value = polars_ffi::version_0::export_series(&out);
}
Err(err) => {
// Set latest error, but leave return value in empty state.
Expand All @@ -68,32 +100,45 @@ fn quote_process_results() -> proc_macro2::TokenStream {

fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
// count how often the user define a kwargs argument.
let n_kwargs = ast
let args = ast
.sig
.inputs
.iter()
.filter(|fn_arg| {
.skip(1)
.map(|fn_arg| {
if let FnArg::Typed(pat) = fn_arg {
if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
pat.ident.to_string() == "kwargs"
pat.ident.to_string()
} else {
false
panic!("expected an argument")
}
} else {
true
panic!("expected a type argument")
}
})
.count();
.collect::<Vec<_>>();

let fn_name = &ast.sig.ident;
let error_msg_fn = insert_error_function();

let quote_call = match n_kwargs {
// Get the tokenstream of the call logic.
let quote_call = match args.len() {
0 => quote_call_no_kwargs(&ast, fn_name),
1 => quote_call_kwargs(&ast, fn_name),
_ => unreachable!(), // arguments are unique
1 => match args[0].as_str() {
"kwargs" => quote_call_kwargs(&ast, fn_name),
"context" => quote_call_context(&ast, fn_name),
a => panic!("didn't expect argument {}", a),
},
2 => match (args[0].as_str(), args[1].as_str()) {
("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),
("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),
(a, b) => panic!("didn't expect arguments {}, {}", a, b),
},
_ => panic!("didn't expect so many arguments"),
};

let quote_process_result = quote_process_results();
let fn_name = get_expression_function_name(fn_name);

quote!(
use pyo3_polars::export::*;
Expand All @@ -103,14 +148,15 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
// create the outer public function
#[no_mangle]
pub unsafe extern "C" fn #fn_name (
e: *mut polars_ffi::SeriesExport,
e: *mut polars_ffi::version_0::SeriesExport,
input_len: usize,
kwargs_ptr: *const u8,
kwargs_len: usize,
return_value: *mut polars_ffi::SeriesExport
return_value: *mut polars_ffi::version_0::SeriesExport,
context: *mut polars_ffi::version_0::CallerContext
) {
let panic_result = std::panic::catch_unwind(move || {
let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap();
let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();

#quote_call

Expand All @@ -119,17 +165,29 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
});

if panic_result.is_err() {
// Set latest to panic and nullify return value;
*return_value = polars_ffi::SeriesExport::empty();
// Set latest to panic;
pyo3_polars::derive::_set_panic();
}

}
)
}

fn get_field_name(fn_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(&format!("__polars_field_{}", fn_name), fn_name.span())
/// Returns the identifier used for the generated field function of `fn_name`:
/// the expression name prefixed with `_polars_plugin_field_`, carrying the
/// same span as `fn_name`.
fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {
    let exported_name = format!("_polars_plugin_field_{}", fn_name);
    syn::Ident::new(&exported_name, fn_name.span())
}

/// Returns the identifier used for the generated `extern "C"` expression
/// function of `fn_name`: the expression name prefixed with `_polars_plugin_`,
/// carrying the same span as `fn_name`.
fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
    let exported_name = format!("_polars_plugin_{}", fn_name);
    syn::Ident::new(&exported_name, fn_name.span())
}

fn get_inputs() -> proc_macro2::TokenStream {
Expand All @@ -147,7 +205,7 @@ fn create_field_function(
fn_name: &syn::Ident,
dtype_fn_name: &syn::Ident,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_name(fn_name);
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();

quote! (
Expand Down Expand Up @@ -175,8 +233,7 @@ fn create_field_function(
});

if panic_result.is_err() {
// Set latest to panic and nullify return value;
*return_value = polars_core::export::arrow::ffi::ArrowSchema::empty();
// Set latest to panic;
pyo3_polars::derive::_set_panic();
}
}
Expand All @@ -187,7 +244,7 @@ fn create_field_function_from_with_dtype(
fn_name: &syn::Ident,
dtype: syn::Ident,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_name(fn_name);
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();

quote! (
Expand Down
Loading

0 comments on commit 04df158

Please sign in to comment.