Skip to content

Commit 7d5e384

Browse files
authored
feat: expression plugins (#26)
1 parent a46d98c commit 7d5e384

File tree

36 files changed

+624
-89
lines changed

36 files changed

+624
-89
lines changed

.github/workflows/CI.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
working-directory: pyo3-polars
3838

3939
- run: make install
40-
working-directory: example
40+
working-directory: example/extend_polars_python_dispatch
4141

4242
- run: venv/bin/python run.py
43-
working-directory: example
43+
working-directory: example/extend_polars_python_dispatch

Cargo.toml

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[workspace]
2+
resolver = "2"
3+
members = [
4+
"example/derive_expression/expression_lib",
5+
"example/extend_polars_python_dispatch/extend_polars",
6+
"pyo3-polars",
7+
"pyo3-polars-derive",
8+
]
9+
10+
[workspace.dependencies]
11+
polars = {version = "0.33.2", default-features=false}
12+
polars-core = {version = "0.33.2", default-features=false}
13+
polars-ffi = {ersion = "0.33.2", default-features=false}
14+
polars-plan = {version = "0.33.2", default-feautres=false}
15+
polars-lazy = {version = "0.33.2", default-features=false}

README.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
# Pyo3 extensions for Polars
1+
## 1. Shared library plugins for Polars
2+
This is new functionality and not entirely stable, but should be preferred over `2.` as this
3+
will circumvent the GIL and will be the way we want to support extending polars.
4+
5+
See more in `examples/derive_expression`.
6+
7+
## 2. Pyo3 extensions for Polars
28
<a href="https://crates.io/crates/pyo3-polars">
39
<img src="https://img.shields.io/crates/v/pyo3-polars.svg"/>
410
</a>

example/derive_expression/Makefile

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
SHELL=/bin/bash
3+
4+
venv: ## Set up virtual environment
5+
python3 -m venv venv
6+
venv/bin/pip install -r requirements.txt
7+
8+
install: venv
9+
unset CONDA_PREFIX && \
10+
source venv/bin/activate && maturin develop -m expression_lib/Cargo.toml
11+
12+
install-release: venv
13+
unset CONDA_PREFIX && \
14+
source venv/bin/activate && maturin develop --release -m expression_lib/Cargo.toml
15+
16+
clean:
17+
-@rm -r venv
18+
-@cd experssion_lib && cargo clean
19+
20+
21+
run: install
22+
source venv/bin/activate && python run.py
23+
24+
run-release: install-release
25+
source venv/bin/activate && python run.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "expression_lib"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7+
[lib]
8+
name = "expression_lib"
9+
crate-type = ["cdylib"]
10+
11+
[dependencies]
12+
pyo3 = { version = "0.19.0", features = ["extension-module"] }
13+
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features=["derive"] }
14+
polars = { workspace = true, features = ["fmt"], default-features=false }
15+
polars-plan = { workspace = true, default-features=false }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import polars as pl
2+
from polars.type_aliases import IntoExpr
3+
from polars.utils.udfs import _get_shared_lib_location
4+
5+
lib = _get_shared_lib_location(__file__)
6+
7+
8+
@pl.api.register_expr_namespace("language")
9+
class Language:
10+
def __init__(self, expr: pl.Expr):
11+
self._expr = expr
12+
13+
def pig_latinnify(self) -> pl.Expr:
14+
return self._expr._register_plugin(
15+
lib=lib,
16+
symbol="pig_latinnify",
17+
is_elementwise=True,
18+
)
19+
20+
@pl.api.register_expr_namespace("dist")
21+
class Distance:
22+
def __init__(self, expr: pl.Expr):
23+
self._expr = expr
24+
25+
def hamming_distance(self, other: IntoExpr) -> pl.Expr:
26+
return self._expr._register_plugin(
27+
lib=lib,
28+
args=[other],
29+
symbol="hamming_distance",
30+
is_elementwise=True,
31+
)
32+
33+
def jaccard_similarity(self, other: IntoExpr) -> pl.Expr:
34+
return self._expr._register_plugin(
35+
lib=lib,
36+
args=[other],
37+
symbol="jaccard_similarity",
38+
is_elementwise=True,
39+
)
40+
41+
def haversine(self, start_lat: IntoExpr, start_long: IntoExpr, end_lat: IntoExpr, end_long: IntoExpr) -> pl.Expr:
42+
return self._expr._register_plugin(
43+
lib=lib,
44+
args=[start_lat, start_long, end_lat, end_long],
45+
symbol="haversine",
46+
is_elementwise=True,
47+
cast_to_supertypes=True
48+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[build-system]
2+
requires = ["maturin>=1.0,<2.0"]
3+
build-backend = "maturin"
4+
5+
[project]
6+
name = "expression_lib"
7+
requires-python = ">=3.8"
8+
classifiers = [
9+
"Programming Language :: Rust",
10+
"Programming Language :: Python :: Implementation :: CPython",
11+
"Programming Language :: Python :: Implementation :: PyPy",
12+
]
13+
14+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use polars::datatypes::PlHashSet;
2+
use polars::export::arrow::array::PrimitiveArray;
3+
use polars::export::num::Float;
4+
use polars::prelude::*;
5+
use pyo3_polars::export::polars_core::utils::arrow::types::NativeType;
6+
use pyo3_polars::export::polars_core::with_match_physical_integer_type;
7+
use std::hash::Hash;
8+
9+
#[allow(clippy::all)]
10+
pub(super) fn naive_hamming_dist(a: &str, b: &str) -> u32 {
11+
let x = a.as_bytes();
12+
let y = b.as_bytes();
13+
x.iter()
14+
.zip(y)
15+
.fold(0, |a, (b, c)| a + (*b ^ *c).count_ones() as u32)
16+
}
17+
18+
fn jacc_helper<T: NativeType + Hash + Eq>(a: &PrimitiveArray<T>, b: &PrimitiveArray<T>) -> f64 {
19+
// convert to hashsets over Option<T>
20+
let s1 = a.into_iter().collect::<PlHashSet<_>>();
21+
let s2 = b.into_iter().collect::<PlHashSet<_>>();
22+
23+
// count the number of intersections
24+
let s3_len = s1.intersection(&s2).count();
25+
// return similarity
26+
s3_len as f64 / (s1.len() + s2.len() - s3_len) as f64
27+
}
28+
29+
pub(super) fn naive_jaccard_sim(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
30+
polars_ensure!(
31+
a.inner_dtype() == b.inner_dtype(),
32+
ComputeError: "inner data types don't match"
33+
);
34+
polars_ensure!(
35+
a.inner_dtype().is_integer(),
36+
ComputeError: "inner data types must be integer"
37+
);
38+
Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| {
39+
polars::prelude::arity::binary_elementwise(a, b, |a, b| {
40+
match (a, b) {
41+
(Some(a), Some(b)) => {
42+
let a = a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
43+
let b = b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
44+
Some(jacc_helper(a, b))
45+
},
46+
_ => None
47+
}
48+
})
49+
}))
50+
}
51+
52+
fn haversine_elementwise<T: Float>(start_lat: T, start_long: T, end_lat: T, end_long: T) -> T {
53+
let r_in_km = T::from(6371.0).unwrap();
54+
let two = T::from(2.0).unwrap();
55+
let one = T::one();
56+
57+
let d_lat = (end_lat - start_lat).to_radians();
58+
let d_lon = (end_long - start_long).to_radians();
59+
let lat1 = (start_lat).to_radians();
60+
let lat2 = (end_lat).to_radians();
61+
62+
let a = ((d_lat / two).sin()) * ((d_lat / two).sin())
63+
+ ((d_lon / two).sin()) * ((d_lon / two).sin()) * (lat1.cos()) * (lat2.cos());
64+
let c = two * ((a.sqrt()).atan2((one - a).sqrt()));
65+
r_in_km * c
66+
}
67+
68+
pub(super) fn naive_haversine<T>(
69+
start_lat: &ChunkedArray<T>,
70+
start_long: &ChunkedArray<T>,
71+
end_lat: &ChunkedArray<T>,
72+
end_long: &ChunkedArray<T>,
73+
) -> PolarsResult<ChunkedArray<T>>
74+
where
75+
T: PolarsFloatType,
76+
T::Native: Float,
77+
{
78+
let out: ChunkedArray<T> = start_lat
79+
.into_iter()
80+
.zip(start_long.into_iter())
81+
.zip(end_lat.into_iter())
82+
.zip(end_long.into_iter())
83+
.map(|(((start_lat, start_long), end_lat), end_long)| {
84+
let start_lat = start_lat?;
85+
let start_long = start_long?;
86+
let end_lat = end_lat?;
87+
let end_long = end_long?;
88+
Some(haversine_elementwise(
89+
start_lat, start_long, end_lat, end_long,
90+
))
91+
})
92+
.collect();
93+
94+
Ok(out.with_name(start_lat.name()))
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use polars::prelude::*;
2+
use polars_plan::dsl::FieldsMapper;
3+
use pyo3_polars::derive::polars_expr;
4+
use std::fmt::Write;
5+
6+
fn pig_latin_str(value: &str, output: &mut String) {
7+
if let Some(first_char) = value.chars().next() {
8+
write!(output, "{}{}ay", &value[1..], first_char).unwrap()
9+
}
10+
}
11+
12+
#[polars_expr(output_type=Utf8)]
13+
fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
14+
let ca = inputs[0].utf8()?;
15+
let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
16+
Ok(out.into_series())
17+
}
18+
19+
#[polars_expr(output_type=Float64)]
20+
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
21+
let a = inputs[0].list()?;
22+
let b = inputs[1].list()?;
23+
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())
24+
}
25+
26+
#[polars_expr(output_type=Float64)]
27+
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {
28+
let a = inputs[0].utf8()?;
29+
let b = inputs[1].utf8()?;
30+
let out: UInt32Chunked =
31+
arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist);
32+
Ok(out.into_series())
33+
}
34+
35+
fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
36+
FieldsMapper::new(input_fields).map_to_float_dtype()
37+
}
38+
39+
#[polars_expr(type_func=haversine_output)]
40+
fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
41+
let out = match inputs[0].dtype() {
42+
DataType::Float32 => {
43+
let start_lat = inputs[0].f32().unwrap();
44+
let start_long = inputs[1].f32().unwrap();
45+
let end_lat = inputs[2].f32().unwrap();
46+
let end_long = inputs[3].f32().unwrap();
47+
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
48+
.into_series()
49+
}
50+
DataType::Float64 => {
51+
let start_lat = inputs[0].f64().unwrap();
52+
let start_long = inputs[1].f64().unwrap();
53+
let end_lat = inputs[2].f64().unwrap();
54+
let end_long = inputs[3].f64().unwrap();
55+
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
56+
.into_series()
57+
}
58+
_ => unimplemented!(),
59+
};
60+
Ok(out)
61+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
mod distances;
2+
mod expressions;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
maturin

example/derive_expression/run.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import polars as pl
2+
from expression_lib import Language, Distance
3+
4+
df = pl.DataFrame({
5+
"names": ["Richard", "Alice", "Bob"],
6+
"moons": ["full", "half", "red"],
7+
"dist_a": [[12, 32, 1], [], [1, -2]],
8+
"dist_b": [[-12, 1], [43], [876, -45, 9]]
9+
})
10+
11+
12+
out = df.with_columns(
13+
pig_latin = pl.col("names").language.pig_latinnify()
14+
).with_columns(
15+
hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"),
16+
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b")
17+
)
18+
19+
print(out)

example/extend_polars/.github/workflows/CI.yml

-70
This file was deleted.

0 commit comments

Comments
 (0)