Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Series.length/1 and Series.member?/2 #746

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ defmodule Explorer.Backend.LazySeries do
minute: 1,
second: 1,
# List functions
join: 2
join: 2,
lengths: 1,
member: 3
]

@comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal]
Expand Down Expand Up @@ -990,6 +992,20 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, :string)
end

@impl true
def lengths(series) do
data = new(:lengths, [lazy_series!(series)], :integer)

Backend.Series.new(data, :integer)
end

@impl true
def member?(%Series{dtype: {:list, inner_dtype}} = series, value) do
data = new(:member, [lazy_series!(series), value, inner_dtype], :boolean)

Backend.Series.new(data, :boolean)
end

@remaining_non_lazy_operations [
at: 2,
at_every: 2,
Expand Down
2 changes: 2 additions & 0 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ defmodule Explorer.Backend.Series do

# List
@callback join(s, String.t()) :: s
@callback lengths(s) :: s
@callback member?(s, valid_types()) :: s

# Functions

Expand Down
4 changes: 3 additions & 1 deletion lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ defmodule Explorer.PolarsBackend.Expression do
split: 2,

# Lists
join: 2
join: 2,
lengths: 1,
member: 3
]

@custom_expressions [
Expand Down
2 changes: 2 additions & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ defmodule Explorer.PolarsBackend.Native do
def s_atan(_s), do: err()

def s_join(_s, _separator), do: err()
def s_lengths(_s), do: err()
def s_member(_s, _value, _inner_dtype), do: err()

defp err, do: :erlang.nif_error(:nif_not_loaded)
end
8 changes: 8 additions & 0 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,14 @@ defmodule Explorer.PolarsBackend.Series do
def join(series, separator),
do: Shared.apply_series(series, :s_join, [separator])

@impl true
def lengths(series),
do: Shared.apply_series(series, :s_lengths)

@impl true
def member?(%Series{dtype: {:list, inner_dtype}} = series, value),
do: Shared.apply_series(series, :s_member, [value, inner_dtype])

# Polars specific functions

def name(series), do: Shared.apply_series(series, :s_name)
Expand Down
42 changes: 42 additions & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5414,6 +5414,48 @@ defmodule Explorer.Series do
def join(%Series{dtype: dtype}, _separator),
do: dtype_error("join/2", dtype, [{:list, :string}])

@doc """
Calculates the length of each list in a list series.

## Examples

iex> s = Series.from_list([[1], [1, 2]])
iex> Series.lengths(s)
#Explorer.Series<
Polars[2]
integer [1, 2]
>

"""
@doc type: :list_wise
@spec lengths(Series.t()) :: Series.t()
def lengths(%Series{dtype: {:list, _}} = series),
do: apply_series(series, :lengths)

def lengths(%Series{dtype: dtype}),
do: dtype_error("lengths/1", dtype, [{:list, :_}])

@doc """
Checks for the presence of a value in a list series.

## Examples

iex> s = Series.from_list([[1], [1, 2]])
iex> Series.member?(s, 2)
#Explorer.Series<
Polars[2]
boolean [false, true]
>

"""
@doc type: :list_wise
@spec member?(Series.t(), Explorer.Backend.Series.valid_types()) :: Series.t()
def member?(%Series{dtype: {:list, _}} = series, value),
do: apply_series(series, :member?, [value])

def member?(%Series{dtype: dtype}, _value),
do: dtype_error("member?/2", dtype, [{:list, :_}])

# Escape hatch

@doc """
Expand Down
113 changes: 113 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ impl From<NaiveDate> for ExDate {
}
}

impl Literal for ExDate {
fn lit(self) -> Expr {
NaiveDate::from(self).lit().dt().date()
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "Explorer.Duration"]
pub struct ExDuration {
Expand All @@ -226,6 +232,30 @@ impl From<ExDuration> for i64 {
}
}

impl Literal for ExDuration {
fn lit(self) -> Expr {
// Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but
// doing so will lose precision information as `chrono::Duration`s have no time units.
Expr::Literal(LiteralValue::Duration(
self.value,
time_unit_of_ex_duration(&self),
))
}
}

fn time_unit_of_ex_duration(duration: &ExDuration) -> TimeUnit {
let precision = duration.precision;
if precision == atoms::millisecond() {
TimeUnit::Milliseconds
} else if precision == atoms::microsecond() {
TimeUnit::Microseconds
} else if precision == atoms::nanosecond() {
TimeUnit::Nanoseconds
} else {
panic!("unrecognized precision: {precision:?}")
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "NaiveDateTime"]
pub struct ExDateTime {
Expand Down Expand Up @@ -318,6 +348,12 @@ impl From<NaiveDateTime> for ExDateTime {
}
}

impl Literal for ExDateTime {
fn lit(self) -> Expr {
NaiveDateTime::from(self).lit()
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "Time"]
pub struct ExTime {
Expand Down Expand Up @@ -379,6 +415,83 @@ impl From<NaiveTime> for ExTime {
}
}

impl Literal for ExTime {
fn lit(self) -> Expr {
Expr::Literal(LiteralValue::Time(self.into()))
}
}

/// Represents valid Elixir types that can be used as literals in Polars.
pub enum ExValidValue<'a> {
I64(i64),
F64(f64),
Bool(bool),
Str(&'a str),
Date(ExDate),
Time(ExTime),
DateTime(ExDateTime),
Duration(ExDuration),
}

impl<'a> ExValidValue<'a> {
pub fn lit_with_matching_precision(self, data_type: &DataType) -> Expr {
match data_type {
DataType::Datetime(time_unit, _) => self.lit().dt().cast_time_unit(*time_unit),
DataType::Duration(time_unit) => self.lit().dt().cast_time_unit(*time_unit),
_ => self.lit(),
}
}
}

impl<'a> Literal for &ExValidValue<'a> {
fn lit(self) -> Expr {
match self {
ExValidValue::I64(v) => v.lit(),
ExValidValue::F64(v) => v.lit(),
ExValidValue::Bool(v) => v.lit(),
ExValidValue::Str(v) => v.lit(),
ExValidValue::Date(v) => v.lit(),
ExValidValue::Time(v) => v.lit(),
ExValidValue::DateTime(v) => v.lit(),
ExValidValue::Duration(v) => v.lit(),
}
}
}

impl<'a> rustler::Decoder<'a> for ExValidValue<'a> {
fn decode(term: rustler::Term<'a>) -> rustler::NifResult<Self> {
use rustler::*;

match term.get_type() {
TermType::Atom => term.decode::<bool>().map(ExValidValue::Bool),
TermType::Binary => term.decode::<&'a str>().map(ExValidValue::Str),
TermType::Number => {
if let Ok(i) = term.decode::<i64>() {
Ok(ExValidValue::I64(i))
} else if let Ok(f) = term.decode::<f64>() {
Ok(ExValidValue::F64(f))
} else {
Err(rustler::Error::BadArg)
}
}
TermType::Map => {
if let Ok(date) = term.decode::<ExDate>() {
Ok(ExValidValue::Date(date))
} else if let Ok(time) = term.decode::<ExTime>() {
Ok(ExValidValue::Time(time))
} else if let Ok(datetime) = term.decode::<ExDateTime>() {
Ok(ExValidValue::DateTime(datetime))
} else if let Ok(duration) = term.decode::<ExDuration>() {
Ok(ExValidValue::Duration(duration))
} else {
Err(rustler::Error::BadArg)
}
}
_ => Err(rustler::Error::BadArg),
}
}
}

// In Elixir this would be represented like this:
// * `:uncompressed` for `ExParquetCompression::Uncompressed`
// * `{:brotli, 7}` for `ExParquetCompression::Brotli(Some(7))`
Expand Down
63 changes: 33 additions & 30 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
// or an expression and returns an expression that is
// wrapped in an Elixir struct.

use chrono::{NaiveDate, NaiveDateTime};
use polars::lazy::dsl::{col, concat_str, cov, pearson_corr, when, Expr, StrptimeOptions};
use polars::prelude::{DataType, Literal, TimeUnit};
use polars::prelude::{IntoLazy, LiteralValue, SortOptions};
use polars::prelude::{
col, concat_str, cov, pearson_corr, when, IntoLazy, LiteralValue, SortOptions,
};
use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit};

use crate::atoms::{microsecond, millisecond, nanosecond};
use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype};
use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue};
use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts};
use crate::{ExDataFrame, ExExpr, ExSeries};

Expand Down Expand Up @@ -54,38 +53,17 @@ pub fn expr_atom(atom: &str) -> ExExpr {

#[rustler::nif]
pub fn expr_date(date: ExDate) -> ExExpr {
let naive_date = NaiveDate::from(date);
let expr = naive_date.lit().dt().date();
ExExpr::new(expr)
ExExpr::new(date.lit())
}

#[rustler::nif]
pub fn expr_datetime(datetime: ExDateTime) -> ExExpr {
let naive_datetime = NaiveDateTime::from(datetime);
let expr = naive_datetime.lit();
ExExpr::new(expr)
ExExpr::new(datetime.lit())
}

#[rustler::nif]
pub fn expr_duration(duration: ExDuration) -> ExExpr {
// Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but
// doing so will lose precision information as `chrono::Duration`s have no time units.
let time_unit = time_unit_of_ex_duration(duration);
let expr = Expr::Literal(LiteralValue::Duration(duration.value, time_unit));
ExExpr::new(expr)
}

fn time_unit_of_ex_duration(duration: ExDuration) -> TimeUnit {
let precision = duration.precision;
if precision == millisecond() {
TimeUnit::Milliseconds
} else if precision == microsecond() {
TimeUnit::Microseconds
} else if precision == nanosecond() {
TimeUnit::Nanoseconds
} else {
panic!("unrecognized precision: {precision:?}")
}
ExExpr::new(duration.lit())
}

#[rustler::nif]
Expand Down Expand Up @@ -977,3 +955,28 @@ pub fn expr_second(expr: ExExpr) -> ExExpr {

ExExpr::new(expr.dt().second().cast(DataType::Int64))
}

#[rustler::nif]
pub fn expr_join(expr: ExExpr, sep: String) -> ExExpr {
let expr = expr.clone_inner();

ExExpr::new(expr.list().join(sep.lit()))
}

#[rustler::nif]
pub fn expr_lengths(expr: ExExpr) -> ExExpr {
let expr = expr.clone_inner();

ExExpr::new(expr.list().len().cast(DataType::Int64))
}

#[rustler::nif]
pub fn expr_member(expr: ExExpr, value: ExValidValue, inner_dtype: ExSeriesDtype) -> ExExpr {
let expr = expr.clone_inner();
let inner_dtype = DataType::try_from(&inner_dtype).unwrap();

ExExpr::new(
expr.list()
.contains(value.lit_with_matching_precision(&inner_dtype)),
)
}
6 changes: 6 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ rustler::init!(
expr_round,
expr_floor,
expr_ceil,
// list expressions
expr_join,
expr_lengths,
expr_member,
// lazyframe
lf_collect,
lf_describe_plan,
Expand Down Expand Up @@ -446,6 +450,8 @@ rustler::init!(
s_floor,
s_ceil,
s_join,
s_lengths,
s_member,
],
load = on_load
);
Loading
Loading