Skip to content

Commit

Permalink
Additional temporal arithmetic (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
billylanchantin authored Aug 29, 2023
1 parent bebee66 commit 48415da
Show file tree
Hide file tree
Showing 12 changed files with 627 additions and 144 deletions.
84 changes: 15 additions & 69 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ defmodule Explorer.Backend.LazySeries do

@comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal]

@arithmetic_operations [:pow, :quotient, :remainder]
@basic_arithmetic_operations [:add, :subtract, :multiply, :divide]
@other_arithmetic_operations [:pow, :quotient, :remainder]

@aggregation_operations [
:sum,
Expand Down Expand Up @@ -192,38 +193,6 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, dtype)
end

@impl true
def add(left, right) do
args = [data!(left), data!(right)]
data = new(:add, args, aggregations?(args))
dtype = resolve_numeric_temporal_dtype(:add, left, right)
Backend.Series.new(data, dtype)
end

@impl true
def subtract(left, right) do
args = [data!(left), data!(right)]
data = new(:subtract, args, aggregations?(args))
dtype = resolve_numeric_temporal_dtype(:subtract, left, right)
Backend.Series.new(data, dtype)
end

@impl true
def multiply(left, right) do
args = [data!(left), data!(right)]
data = new(:multiply, args, aggregations?(args))
dtype = resolve_numeric_temporal_dtype(:multiply, left, right)
Backend.Series.new(data, dtype)
end

@impl true
def divide(left, right) do
args = [data!(left), data!(right)]
data = new(:divide, args, aggregations?(args))
dtype = resolve_numeric_temporal_dtype(:divide, left, right)
Backend.Series.new(data, dtype)
end

@impl true
def from_list(list, dtype) when is_list(list) and dtype in @valid_dtypes do
data = new(:from_list, [list, dtype], false)
Expand Down Expand Up @@ -412,7 +381,19 @@ defmodule Explorer.Backend.LazySeries do
end
end

for op <- @arithmetic_operations do
for op <- @basic_arithmetic_operations do
@impl true
def unquote(op)(%Series{} = left, %Series{} = right) do
dtype = Explorer.Shared.cast_to_arithmetic(unquote(op), dtype(left), dtype(right))

args = [data!(left), data!(right)]
data = new(unquote(op), args, aggregations?(args))

Backend.Series.new(data, dtype)
end
end

for op <- @other_arithmetic_operations do
@impl true
def unquote(op)(left, right) do
dtype = resolve_numeric_dtype([left, right])
Expand Down Expand Up @@ -654,41 +635,6 @@ defmodule Explorer.Backend.LazySeries do
defp resolve_numeric_dtype(:window_mean, _items), do: :float
defp resolve_numeric_dtype(_op, items), do: resolve_numeric_dtype(items)

defp resolve_numeric_temporal_dtype(op, %Series{dtype: ldt} = left, %Series{dtype: rdt} = right) do
case {op, ldt, rdt} do
{:add, {:datetime, ltu}, {:duration, rtu}} -> {:datetime, highest_precision(ltu, rtu)}
{:add, {:duration, ltu}, {:datetime, rtu}} -> {:datetime, highest_precision(ltu, rtu)}
{:add, {:duration, ltu}, {:duration, rtu}} -> {:duration, highest_precision(ltu, rtu)}
{:subtract, {:datetime, ltu}, {:datetime, rtu}} -> {:duration, highest_precision(ltu, rtu)}
{:subtract, {:datetime, ltu}, {:duration, rtu}} -> {:datetime, highest_precision(ltu, rtu)}
{:subtract, {:duration, ltu}, {:duration, rtu}} -> {:duration, highest_precision(ltu, rtu)}
{:multiply, :integer, {:duration, tu}} -> {:duration, tu}
{:multiply, {:duration, tu}, :integer} -> {:duration, tu}
{:divide, {:duration, tu}, :integer} -> {:duration, tu}
{:divide, _, {:duration, _}} -> raise("cannot divide by duration")
{:divide, _, _} -> :float
_ -> resolve_numeric_dtype([left, right])
end
end

defp resolve_numeric_temporal_dtype(op, left, right) do
case op do
:divide -> :float
_ -> resolve_numeric_dtype([left, right])
end
end

defp highest_precision(left_timeunit, right_timeunit) do
# Higher precision wins, otherwise information is lost.
case {left_timeunit, right_timeunit} do
{equal, equal} -> equal
{:nanosecond, _} -> :nanosecond
{_, :nanosecond} -> :nanosecond
{:microsecond, _} -> :microsecond
{_, :microsecond} -> :microsecond
end
end

# Returns the inner `data` if it's a lazy series. Otherwise raises an error.
defp lazy_series!(series) do
case series do
Expand Down
5 changes: 5 additions & 0 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2813,6 +2813,11 @@ defmodule Explorer.DataFrame do

Explorer.Backend.Series.new(lazy_s, {:datetime, :microsecond})

duration = %Explorer.Duration{} ->
lazy_s = LazySeries.new(:to_lazy, [duration])

Explorer.Backend.Series.new(lazy_s, {:datetime, duration.precision})

other ->
raise ArgumentError,
"expecting a lazy series or scalar value, but instead got #{inspect(other)}"
Expand Down
3 changes: 3 additions & 0 deletions lib/explorer/duration.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ defmodule Explorer.Duration do
@enforce_keys [:value, :precision]
defstruct [:value, :precision]

@type precision :: :millisecond | :microsecond | :nanosecond
@type t :: %__MODULE__{value: integer(), precision: precision()}

# Nanosecond constants
@us_ns 1_000
@ms_ns 1_000 * @us_ns
Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ defmodule Explorer.PolarsBackend.Expression do
def to_expr(number) when is_float(number), do: Native.expr_float(number)
def to_expr(%Date{} = date), do: Native.expr_date(date)
def to_expr(%NaiveDateTime{} = datetime), do: Native.expr_datetime(datetime)
def to_expr(%Explorer.Duration{} = duration), do: Native.expr_duration(duration)
def to_expr(%PolarsSeries{} = polars_series), do: Native.expr_series(polars_series)

# Used by Explorer.PolarsBackend.DataFrame
Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ defmodule Explorer.PolarsBackend.Native do
def expr_boolean(_bool), do: err()
def expr_date(_date), do: err()
def expr_datetime(_datetime), do: err()
def expr_duration(_duration), do: err()
def expr_describe_filter_plan(_df, _expr), do: err()
def expr_float(_number), do: err()
def expr_integer(_number), do: err()
Expand Down
47 changes: 41 additions & 6 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -279,20 +279,55 @@ defmodule Explorer.PolarsBackend.Series do
# Arithmetic

@impl true
def add(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_add, [right.data])
def add(left, right) do
left = matching_size!(left, right)

# `duration + date` is not supported by polars for some reason.
# `date + duration` is, so we're swapping arguments as a work around.
[left, right] =
case {dtype(left), dtype(right)} do
{{:duration, _}, :date} -> [right, left]
_ -> [left, right]
end

Shared.apply_series(left, :s_add, [right.data])
end

@impl true
def subtract(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_subtract, [right.data])

@impl true
def multiply(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_multiply, [right.data])
def multiply(left, right) do
result = Shared.apply_series(matching_size!(left, right), :s_multiply, [right.data])
expected_dtype = Explorer.Shared.cast_to_arithmetic(:multiply, dtype(left), dtype(right))

# Polars currently returns inconsistent dtypes, e.g.:
# * `integer * duration -> duration` when `integer` is a scalar
# * `integer * duration -> integer` when `integer` is a series
# We need to return duration in these cases, so we need an additional cast.
if match?({:duration, _}, expected_dtype) and expected_dtype != dtype(result) do
cast(result, expected_dtype)
else
result
end
end

@impl true
def divide(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_divide, [right.data])
def divide(left, right) do
result = Shared.apply_series(matching_size!(left, right), :s_divide, [right.data])
expected_dtype = Explorer.Shared.cast_to_arithmetic(:divide, dtype(left), dtype(right))

# Polars currently returns inconsistent dtypes, e.g.:
# * `duration / integer -> duration` when `integer` is a scalar
# * `duration / integer -> integer` when `integer` is a series
# We need to return duration in these cases, so we need an additional cast.
if match?({:duration, _}, expected_dtype) and expected_dtype != dtype(result) do
cast(result, expected_dtype)
else
result
end
end

@impl true
def quotient(left, right),
Expand Down
Loading

0 comments on commit 48415da

Please sign in to comment.