/// Appends the pig-latin form of `value` to `output`.
///
/// The first character of `value` is moved behind the remaining characters and
/// followed by `ay`. When `capitalize` is `true` the whole result, including the
/// moved first character and the suffix, is upper-cased. An empty `value`
/// appends nothing. Existing content of `output` is left untouched.
fn pig_latin_str(value: &str, capitalize: bool, output: &mut String) {
    let mut chars = value.chars();
    if let Some(first_char) = chars.next() {
        // Everything after the first char; `as_str` always sits on a char boundary,
        // so multi-byte first characters are handled correctly.
        let rest = chars.as_str();
        if capitalize {
            // The moved first character must be part of the output as well,
            // otherwise the capitalized form silently loses a letter.
            let moved = rest.chars().chain(std::iter::once(first_char));
            for c in moved.flat_map(char::to_uppercase) {
                output.push(c);
            }
            output.push_str("AY");
        } else {
            // Writing into a `String` cannot fail.
            write!(output, "{rest}{first_char}ay").unwrap()
        }
    }
}
/// Writes the pig-latin translation of `value` onto the end of `output`.
///
/// The translation is "remaining characters + first character + `ay`".
/// With `capitalize` set, every character of that translation (including the
/// relocated first character and the suffix) is upper-cased. Nothing is written
/// for an empty `value`, and whatever `output` already holds is preserved.
fn pig_latin_str(value: &str, capitalize: bool, output: &mut String) {
    let mut chars = value.chars();
    if let Some(first_char) = chars.next() {
        // Remainder after the first char, guaranteed to start on a char boundary.
        let rest = chars.as_str();
        if capitalize {
            // Upper-case the remainder AND the relocated first character; the
            // previous version dropped the first character in this branch.
            output.extend(
                rest.chars()
                    .chain(std::iter::once(first_char))
                    .flat_map(char::to_uppercase),
            );
            output.push_str("AY");
        } else {
            write!(output, "{rest}{first_char}ay").expect("writing to a String cannot fail")
        }
    }
}