Skip to content

Commit

Permalink
Add Series.lengths/1 and Series.member?/2 (#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
costaraphael authored Nov 28, 2023
1 parent 5694e79 commit acd6756
Show file tree
Hide file tree
Showing 12 changed files with 361 additions and 34 deletions.
18 changes: 17 additions & 1 deletion lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ defmodule Explorer.Backend.LazySeries do
minute: 1,
second: 1,
# List functions
join: 2
join: 2,
lengths: 1,
member: 3
]

@comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal]
Expand Down Expand Up @@ -990,6 +992,20 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, :string)
end

@impl true
def lengths(series) do
  # Build the `:lengths` operation over the lazy form of `series`, then wrap
  # it in a backend series; `:integer` is passed to both constructors.
  :lengths
  |> new([lazy_series!(series)], :integer)
  |> Backend.Series.new(:integer)
end

@impl true
def member?(%Series{dtype: {:list, inner_dtype}} = series, value) do
  # The list's inner dtype is forwarded as a third argument next to the
  # searched value, so the operation carries three arguments in total.
  args = [lazy_series!(series), value, inner_dtype]

  :member
  |> new(args, :boolean)
  |> Backend.Series.new(:boolean)
end

@remaining_non_lazy_operations [
at: 2,
at_every: 2,
Expand Down
2 changes: 2 additions & 0 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ defmodule Explorer.Backend.Series do

# List
# Callbacks for list-typed series. Each takes a series and returns a series.
# `join/2` additionally receives a `String.t()` argument.
@callback join(s, String.t()) :: s
# `lengths/1` backs `Series.lengths/1` (length of each list in the series).
@callback lengths(s) :: s
# `member?/2` backs `Series.member?/2`; the second argument is the value
# whose presence is checked and must be one of `valid_types()`.
@callback member?(s, valid_types()) :: s

# Functions

Expand Down
4 changes: 3 additions & 1 deletion lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ defmodule Explorer.PolarsBackend.Expression do
split: 2,

# Lists
join: 2
join: 2,
lengths: 1,
member: 3
]

@custom_expressions [
Expand Down
2 changes: 2 additions & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ defmodule Explorer.PolarsBackend.Native do
def s_atan(_s), do: err()

# List functions. As written, each stub ignores its arguments and calls
# `err/0`, i.e. `:erlang.nif_error(:nif_not_loaded)`.
# NOTE(review): presumably the loaded native library replaces these at
# runtime; confirm against the module's NIF setup.
def s_join(_s, _separator), do: err()
def s_lengths(_s), do: err()
def s_member(_s, _value, _inner_dtype), do: err()

defp err, do: :erlang.nif_error(:nif_not_loaded)
end
8 changes: 8 additions & 0 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,14 @@ defmodule Explorer.PolarsBackend.Series do
def join(series, separator),
do: Shared.apply_series(series, :s_join, [separator])

@impl true
def lengths(series) do
  # Hands the series to the shared helper under the `:s_lengths` operation name.
  Shared.apply_series(series, :s_lengths)
end

@impl true
def member?(%Series{dtype: {:list, inner_dtype}} = series, value) do
  # The inner dtype of the list is passed along with the searched value as
  # the extra arguments of the `:s_member` operation.
  extra_args = [value, inner_dtype]

  Shared.apply_series(series, :s_member, extra_args)
end

# Polars specific functions

def name(series), do: Shared.apply_series(series, :s_name)
Expand Down
42 changes: 42 additions & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5431,6 +5431,48 @@ defmodule Explorer.Series do
def join(%Series{dtype: dtype}, _separator),
do: dtype_error("join/2", dtype, [{:list, :string}])

@doc """
Calculates the length of each list in a list series.

## Examples

    iex> s = Series.from_list([[1], [1, 2]])
    iex> Series.lengths(s)
    #Explorer.Series<
      Polars[2]
      integer [1, 2]
    >

"""
@doc type: :list_wise
@spec lengths(Series.t()) :: Series.t()
def lengths(%Series{dtype: {:list, _}} = series),
  do: apply_series(series, :lengths)

# Any series whose dtype is not `{:list, _}` is handed to `dtype_error/3`,
# which is told that a list dtype was expected.
def lengths(%Series{dtype: dtype}),
  do: dtype_error("lengths/1", dtype, [{:list, :_}])

@doc """
Checks for the presence of a value in a list series.

The result is a boolean series with one entry per list, as shown in the
example below.

## Examples

    iex> s = Series.from_list([[1], [1, 2]])
    iex> Series.member?(s, 2)
    #Explorer.Series<
      Polars[2]
      boolean [false, true]
    >

"""
@doc type: :list_wise
@spec member?(Series.t(), Explorer.Backend.Series.valid_types()) :: Series.t()
def member?(%Series{dtype: {:list, _}} = series, value),
  do: apply_series(series, :member?, [value])

# Any series whose dtype is not `{:list, _}` is handed to `dtype_error/3`,
# which is told that a list dtype was expected.
def member?(%Series{dtype: dtype}, _value),
  do: dtype_error("member?/2", dtype, [{:list, :_}])

# Escape hatch

@doc """
Expand Down
113 changes: 113 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ impl From<NaiveDate> for ExDate {
}
}

impl Literal for ExDate {
fn lit(self) -> Expr {
NaiveDate::from(self).lit().dt().date()
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "Explorer.Duration"]
pub struct ExDuration {
Expand All @@ -226,6 +232,30 @@ impl From<ExDuration> for i64 {
}
}

impl Literal for ExDuration {
    /// Builds a duration literal from the raw `value` and the time unit that
    /// corresponds to the duration's `precision`.
    fn lit(self) -> Expr {
        // Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but
        // doing so will lose precision information as `chrono::Duration`s have no time units.
        let unit = time_unit_of_ex_duration(&self);
        let literal = LiteralValue::Duration(self.value, unit);

        Expr::Literal(literal)
    }
}

/// Maps the `precision` atom of a duration to the matching `TimeUnit`.
///
/// Panics when the atom is none of `millisecond`, `microsecond` or
/// `nanosecond`.
fn time_unit_of_ex_duration(duration: &ExDuration) -> TimeUnit {
    match duration.precision {
        p if p == atoms::millisecond() => TimeUnit::Milliseconds,
        p if p == atoms::microsecond() => TimeUnit::Microseconds,
        p if p == atoms::nanosecond() => TimeUnit::Nanoseconds,
        other => panic!("unrecognized precision: {other:?}"),
    }
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "NaiveDateTime"]
pub struct ExDateTime {
Expand Down Expand Up @@ -318,6 +348,12 @@ impl From<NaiveDateTime> for ExDateTime {
}
}

impl Literal for ExDateTime {
fn lit(self) -> Expr {
NaiveDateTime::from(self).lit()
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "Time"]
pub struct ExTime {
Expand Down Expand Up @@ -379,6 +415,83 @@ impl From<NaiveTime> for ExTime {
}
}

impl Literal for ExTime {
    /// Wraps the converted time value in a `LiteralValue::Time` literal.
    fn lit(self) -> Expr {
        // The target type of `into()` is whatever `LiteralValue::Time` holds.
        let value = self.into();

        Expr::Literal(LiteralValue::Time(value))
    }
}

/// Represents valid Elixir types that can be used as literals in Polars.
///
/// Each variant wraps the Rust-side representation of one kind of value; the
/// `Decoder` implementation for this type picks the variant based on the
/// term's type.
pub enum ExValidValue<'a> {
/// A number that decodes as `i64`.
I64(i64),
/// A number that does not decode as `i64` but decodes as `f64`.
F64(f64),
/// An atom that decodes as a boolean.
Bool(bool),
/// A binary decoded as a string slice borrowed for `'a`.
Str(&'a str),
/// A map that decodes as `ExDate`.
Date(ExDate),
/// A map that decodes as `ExTime`.
Time(ExTime),
/// A map that decodes as `ExDateTime`.
DateTime(ExDateTime),
/// A map that decodes as `ExDuration`.
Duration(ExDuration),
}

impl<'a> ExValidValue<'a> {
    /// Returns the literal expression for this value, adjusted to `data_type`.
    ///
    /// When `data_type` is a `Datetime` or a `Duration`, the literal is cast to
    /// that type's time unit. For every other data type the literal is
    /// returned unchanged.
    pub fn lit_with_matching_precision(self, data_type: &DataType) -> Expr {
        let literal = self.lit();

        match data_type {
            DataType::Datetime(unit, _) | DataType::Duration(unit) => {
                literal.dt().cast_time_unit(*unit)
            }
            _ => literal,
        }
    }
}

impl<'a> Literal for &ExValidValue<'a> {
    /// Delegates to the `Literal` implementation of the wrapped value.
    fn lit(self) -> Expr {
        // Every payload is copied out of the reference before `.lit()` is
        // called on it.
        match *self {
            ExValidValue::I64(int) => int.lit(),
            ExValidValue::F64(float) => float.lit(),
            ExValidValue::Bool(flag) => flag.lit(),
            ExValidValue::Str(text) => text.lit(),
            ExValidValue::Date(date) => date.lit(),
            ExValidValue::Time(time) => time.lit(),
            ExValidValue::DateTime(datetime) => datetime.lit(),
            ExValidValue::Duration(duration) => duration.lit(),
        }
    }
}

impl<'a> rustler::Decoder<'a> for ExValidValue<'a> {
fn decode(term: rustler::Term<'a>) -> rustler::NifResult<Self> {
use rustler::*;

match term.get_type() {
TermType::Atom => term.decode::<bool>().map(ExValidValue::Bool),
TermType::Binary => term.decode::<&'a str>().map(ExValidValue::Str),
TermType::Number => {
if let Ok(i) = term.decode::<i64>() {
Ok(ExValidValue::I64(i))
} else if let Ok(f) = term.decode::<f64>() {
Ok(ExValidValue::F64(f))
} else {
Err(rustler::Error::BadArg)
}
}
TermType::Map => {
if let Ok(date) = term.decode::<ExDate>() {
Ok(ExValidValue::Date(date))
} else if let Ok(time) = term.decode::<ExTime>() {
Ok(ExValidValue::Time(time))
} else if let Ok(datetime) = term.decode::<ExDateTime>() {
Ok(ExValidValue::DateTime(datetime))
} else if let Ok(duration) = term.decode::<ExDuration>() {
Ok(ExValidValue::Duration(duration))
} else {
Err(rustler::Error::BadArg)
}
}
_ => Err(rustler::Error::BadArg),
}
}
}

// In Elixir this would be represented like this:
// * `:uncompressed` for `ExParquetCompression::Uncompressed`
// * `{:brotli, 7}` for `ExParquetCompression::Brotli(Some(7))`
Expand Down
63 changes: 33 additions & 30 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
// or an expression and returns an expression that is
// wrapped in an Elixir struct.

use chrono::{NaiveDate, NaiveDateTime};
use polars::lazy::dsl::{col, concat_str, cov, pearson_corr, when, Expr, StrptimeOptions};
use polars::prelude::{DataType, Literal, TimeUnit};
use polars::prelude::{IntoLazy, LiteralValue, SortOptions};
use polars::prelude::{
col, concat_str, cov, pearson_corr, when, IntoLazy, LiteralValue, SortOptions,
};
use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit};

use crate::atoms::{microsecond, millisecond, nanosecond};
use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype};
use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue};
use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts};
use crate::{ExDataFrame, ExExpr, ExSeries};

Expand Down Expand Up @@ -54,38 +53,17 @@ pub fn expr_atom(atom: &str) -> ExExpr {

#[rustler::nif]
pub fn expr_date(date: ExDate) -> ExExpr {
let naive_date = NaiveDate::from(date);
let expr = naive_date.lit().dt().date();
ExExpr::new(expr)
ExExpr::new(date.lit())
}

#[rustler::nif]
pub fn expr_datetime(datetime: ExDateTime) -> ExExpr {
let naive_datetime = NaiveDateTime::from(datetime);
let expr = naive_datetime.lit();
ExExpr::new(expr)
ExExpr::new(datetime.lit())
}

#[rustler::nif]
pub fn expr_duration(duration: ExDuration) -> ExExpr {
// Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but
// doing so will lose precision information as `chrono::Duration`s have no time units.
let time_unit = time_unit_of_ex_duration(duration);
let expr = Expr::Literal(LiteralValue::Duration(duration.value, time_unit));
ExExpr::new(expr)
}

fn time_unit_of_ex_duration(duration: ExDuration) -> TimeUnit {
let precision = duration.precision;
if precision == millisecond() {
TimeUnit::Milliseconds
} else if precision == microsecond() {
TimeUnit::Microseconds
} else if precision == nanosecond() {
TimeUnit::Nanoseconds
} else {
panic!("unrecognized precision: {precision:?}")
}
ExExpr::new(duration.lit())
}

#[rustler::nif]
Expand Down Expand Up @@ -977,3 +955,28 @@ pub fn expr_second(expr: ExExpr) -> ExExpr {

ExExpr::new(expr.dt().second().cast(DataType::Int64))
}

/// Joins the elements of each list in `expr`, using `sep` as the separator literal.
#[rustler::nif]
pub fn expr_join(expr: ExExpr, sep: String) -> ExExpr {
    let lists = expr.clone_inner();
    let joined = lists.list().join(sep.lit());

    ExExpr::new(joined)
}

/// Takes the list length of `expr` and casts the result to `Int64`.
#[rustler::nif]
pub fn expr_lengths(expr: ExExpr) -> ExExpr {
    let lists = expr.clone_inner();
    let lengths = lists.list().len();

    ExExpr::new(lengths.cast(DataType::Int64))
}

/// Builds a list `contains` expression for `value`.
///
/// `inner_dtype` is converted to a Polars `DataType` so the literal built
/// from `value` can be given a matching time unit. The conversion is
/// unwrapped, so a dtype that cannot be converted panics.
#[rustler::nif]
pub fn expr_member(expr: ExExpr, value: ExValidValue, inner_dtype: ExSeriesDtype) -> ExExpr {
    let lists = expr.clone_inner();
    let dtype = DataType::try_from(&inner_dtype).unwrap();
    let needle = value.lit_with_matching_precision(&dtype);

    ExExpr::new(lists.list().contains(needle))
}
6 changes: 6 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ rustler::init!(
expr_round,
expr_floor,
expr_ceil,
// list expressions
expr_join,
expr_lengths,
expr_member,
// lazyframe
lf_collect,
lf_describe_plan,
Expand Down Expand Up @@ -446,6 +450,8 @@ rustler::init!(
s_floor,
s_ceil,
s_join,
s_lengths,
s_member,
],
load = on_load
);
Loading

0 comments on commit acd6756

Please sign in to comment.