Skip to content

Commit

Permalink
fix: type_func
Browse files Browse the repository at this point in the history
  • Loading branch information
andysham authored Oct 6, 2023
1 parent 0165cb4 commit f84281e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 5 additions & 3 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]]
"dist_b": [[-12, 1], [43], [876, -45, 9]],
"floats": [5.6, -1245.8, 242.224]
})


out = df.with_columns(
pig_latin = pl.col("names").language.pig_latinnify()
pig_latin = pl.col("names").language.pig_latinnify(),
).with_columns(
hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"),
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b")
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b"),
haversine = pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"),
)

print(out)
9 changes: 6 additions & 3 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ fn get_inputs() -> proc_macro2::TokenStream {
)
}

fn create_field_function(fn_name: &syn::Ident) -> proc_macro2::TokenStream {
fn create_field_function(
fn_name: &syn::Ident,
dtype_fn_name: &syn::Ident
) -> proc_macro2::TokenStream {
let map_field_name = get_field_name(fn_name);
let inputs = get_inputs();

quote! (
#[no_mangle]
pub unsafe extern "C" fn #map_field_name(field: *mut polars_core::export::arrow::ffi::ArrowSchema, len: usize) -> polars_core::export::arrow::ffi::ArrowSchema {
#inputs;
let out = #fn_name(&inputs).unwrap();
let out = #dtype_fn_name(&inputs).unwrap();
polars_core::export::arrow::ffi::export_field_to_c(&out.to_arrow())
}
)
Expand Down Expand Up @@ -81,7 +84,7 @@ pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {

let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);
let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {
create_field_function(&fn_name)
create_field_function(&ast.sig.ident, &fn_name)
} else if let Some(dtype) = options.output_dtype {
create_field_function_from_with_dtype(&ast.sig.ident, dtype)
} else {
Expand Down

0 comments on commit f84281e

Please sign in to comment.