Skip to content

Commit

Permalink
feat: add numeric functions (letsql#80)
Browse files: browse the repository at this point in the history
- bitwise_not
- clip
- to_interval (from integer)
  • Loading branch information
mesejo committed Jun 10, 2024
1 parent ed0eb5b commit 5d705a0
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 4 deletions.
22 changes: 19 additions & 3 deletions python/letsql/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,10 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.ArrayFilter,
ops.ArrayMap,
ops.ArrayZip,
ops.BitwiseNot,
ops.Clip,
ops.CountDistinctStar,
ops.DateDelta,
ops.Greatest,
ops.GroupConcat,
ops.IntervalFromInteger,
ops.Least,
ops.MultiQuantile,
ops.Quantile,
Expand Down Expand Up @@ -499,3 +496,22 @@ def visit_HexDigest(self, op, *, arg, how):

def visit_TypeOf(self, op, *, arg):
    """Compile a type-of operation by calling ``arrow_typeof`` through ``self.f``.

    ``op`` is accepted for interface compatibility and is not used here.
    """
    typeof = self.f.arrow_typeof
    return typeof(arg)

def visit_BitwiseNot(self, op, *, arg):
    """Compile bitwise NOT as ``arg XOR -1``.

    Builds a ``BitwiseXor`` node whose right-hand operand is the literal
    ``-1``. Background: https://stackoverflow.com/q/69648488/4001592
    """
    minus_one = sg.exp.convert(-1)
    return sge.BitwiseXor(this=arg, expression=minus_one)

def visit_Clip(self, op, *, arg, lower, upper):
    """Compile a clip operation into a ``CASE`` expression.

    Emits ``WHEN arg < lower THEN lower`` and/or ``WHEN arg > upper THEN
    upper`` for whichever bounds are provided, with ``arg`` as the default
    (``ELSE``) value. When neither bound is provided, ``arg`` is returned
    unchanged instead of building a ``CASE`` with no ``WHEN`` branches.

    ``op`` is accepted for interface compatibility and is not used here.
    """
    branches = []
    if lower is not None:
        branches.append(self.if_(arg < lower, lower))
    if upper is not None:
        branches.append(self.if_(arg > upper, upper))

    if not branches:
        # Nothing to clamp against: a CASE without any WHEN clause is not
        # valid SQL, so hand the argument back untouched.
        return arg

    return sg.exp.Case(ifs=branches, default=arg)

def visit_IntervalFromInteger(self, op, *, arg, unit):
    """Compile an integer-to-interval conversion via a string cast.

    The integer is cast to a string, a space followed by the lower-cased
    unit name is appended, and the concatenation is cast to ``interval``.
    """
    suffix = f" {unit.name.lower()}"
    as_text = self.cast(arg, dt.string)
    return sg.cast(self.f.concat(as_text, suffix), "interval")
55 changes: 55 additions & 0 deletions python/letsql/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,58 @@ def test_bitwise_scalars(con, op, left, right):
result = con.execute(expr)
expected = op(4, 2)
assert result == expected


def test_bitwise_not_scalar(con):
    """Bitwise NOT of the literal ``2`` executes to ``-3``."""
    assert con.execute(~L(2)) == -3


def test_bitwise_not_col(alltypes, df):
    """Bitwise NOT of ``int_col`` matches ``~`` applied to the pandas column."""
    actual = (~alltypes.int_col).name("tmp").execute()
    wanted = (~df.int_col).rename("tmp")
    assert_series_equal(actual, wanted)


@pytest.mark.parametrize(
    ("ibis_func", "pandas_func"),
    [
        param(
            lambda t: t.clip(lower=0),
            lambda s: s.clip(lower=0),
            id="lower-int",
        ),
        param(
            lambda t: t.clip(lower=0.0),
            lambda s: s.clip(lower=0.0),
            id="lower-float",
        ),
        param(
            lambda t: t.clip(upper=0),
            lambda s: s.clip(upper=0),
            id="upper-int",
        ),
        param(
            lambda t: t.clip(lower=t - 1, upper=t + 1),
            lambda s: s.clip(lower=s - 1, upper=s + 1),
            id="lower-upper-expr",
        ),
        param(
            lambda t: t.clip(lower=0, upper=1),
            lambda s: s.clip(lower=0, upper=1),
            id="lower-upper-int",
        ),
        param(
            lambda t: t.clip(lower=0, upper=1.0),
            lambda s: s.clip(lower=0, upper=1.0),
            id="lower-upper-float",
        ),
        param(
            lambda t: t.nullif(1).clip(lower=0),
            lambda s: s.where(s != 1).clip(lower=0),
            id="null-lower",
        ),
        param(
            lambda t: t.nullif(1).clip(upper=0),
            lambda s: s.where(s != 1).clip(upper=0),
            id="null-upper",
        ),
    ],
)
def test_clip(alltypes, df, ibis_func, pandas_func):
    """Clipping ``int_col`` through the backend matches the pandas equivalent.

    The pandas result is cast to the dtype of the executed result before
    comparing, and series names are not compared.
    """
    actual = ibis_func(alltypes.int_col).execute()
    wanted = pandas_func(df.int_col).astype(actual.dtype)
    assert_series_equal(actual, wanted, check_names=False)
95 changes: 94 additions & 1 deletion python/letsql/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import operator
import warnings
from operator import methodcaller

import ibis
Expand All @@ -11,7 +12,11 @@
import pytest
from pytest import param

from letsql.tests.util import assert_frame_equal, assert_series_equal
from letsql.tests.util import (
assert_frame_equal,
assert_series_equal,
default_series_rename,
)


@pytest.mark.parametrize("attr", ["year", "month", "day"])
Expand Down Expand Up @@ -516,3 +521,91 @@ def test_now_from_projection(alltypes):
assert len(result) == n
assert ts.nunique() == 1
assert not pd.isna(ts.iat[0])


@pytest.mark.parametrize("unit", ["Y", "M", "W", "D"])
def test_integer_to_interval_date(con, alltypes, df, unit):
    """Adding an integer-derived interval to a date matches pandas.

    A date column is rebuilt from ``date_string_col`` (split on ``"/"`` into
    month, day and a two-digit year prefixed with ``"20"``), ``int_col``
    converted to an interval of ``unit`` is added to it, and the executed
    result is compared with the same computation done in pandas using
    ``DateOffset``.
    """
    interval = alltypes.int_col.to_interval(unit=unit)
    array = alltypes.date_string_col.split("/")
    month, day, year = array[0], array[1], array[2]
    date_col = ibis.literal("-").join(["20" + year, month, day]).cast("date")
    expr = (date_col + interval).name("tmp")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
        result = con.execute(expr)

    # Keyword name passed to DateOffset; identical for every row, so it is
    # computed once instead of inside the per-row callback.
    resolution = f"{interval.type().resolution}s"

    def convert_to_offset(x):
        return pd.offsets.DateOffset(**{resolution: x})

    offset = df.int_col.apply(convert_to_offset)
    with warnings.catch_warnings():
        # One filter per category: the documented ``category`` argument is a
        # single Warning subclass, not a tuple of them.
        warnings.simplefilter("ignore", category=UserWarning)
        warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
        expected = (
            pd.to_datetime(df.date_string_col)
            .add(offset)
            .map(lambda ts: ts.normalize().date(), na_action="ignore")
        )

    expected = default_series_rename(expected)
    assert_series_equal(result, expected)


@pytest.mark.parametrize(
    ("unit", "displacement_type"),
    [
        param("Y", pd.offsets.DateOffset),
        param("M", pd.offsets.DateOffset),
        param("W", pd.offsets.DateOffset),
        param("D", pd.offsets.DateOffset),
        param("h", pd.Timedelta),
        param("m", pd.Timedelta),
        param("s", pd.Timedelta),
        param("ms", pd.Timedelta),
        param("us", pd.Timedelta),
    ],
)
def test_integer_to_interval_timestamp(con, alltypes, df, unit, displacement_type):
    """Adding an integer-derived interval to a timestamp matches pandas.

    ``int_col`` is converted to an interval of ``unit`` and added to
    ``timestamp_col``; the pandas reference builds one ``displacement_type``
    per row and adds it to the same column. The reference is cast to the
    result's dtype before comparing.
    """
    interval = alltypes.int_col.to_interval(unit=unit)
    expr = (alltypes.timestamp_col + interval).name("tmp")

    def convert_to_offset(offset, displacement_type=displacement_type):
        keyword = f"{interval.op().dtype.resolution}s"
        return displacement_type(**{keyword: offset})

    with warnings.catch_warnings():
        # Both the implementation and the test code raise pandas'
        # PerformanceWarning, because DateOffset addition is used.
        warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
        result = con.execute(expr)
        shifts = df.int_col.apply(convert_to_offset)
        expected = default_series_rename(df.timestamp_col + shifts)

    assert_series_equal(result, expected.astype(result.dtype))
4 changes: 4 additions & 0 deletions python/letsql/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def assert_frame_equal(
right = right.reset_index(drop=True)
kwargs.setdefault("check_dtype", True)
tm.assert_frame_equal(left, right, *args, **kwargs)


def default_series_rename(series: pd.Series, name: str = "tmp") -> pd.Series:
    """Return ``series`` renamed to ``name`` (``"tmp"`` by default)."""
    renamed = series.rename(name)
    return renamed

0 comments on commit 5d705a0

Please sign in to comment.