Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): move from .case() to .cases() #9096

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ quartodoc:
- name: ifelse
dynamic: true
signature_name: full
- name: case
- name: cases
dynamic: true
signature_name: full

Expand Down
12 changes: 5 additions & 7 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,11 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
improvements=_.raw_improvements.cases(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the files in this directory have pre-computed outputs, and if you change them, they'll rerun.

In this case, this causes a failure, because it's trying to authenticate to BigQuery (I'm guessing browser-based auth is first in the credential chain, and thus the message about a missing browser).

Even if the job did have permission to render this Quarto document, it would still fail, because we verify that "frozen" documents aren't being re-rendered unintentionally.

When something like this fails, it really shouldn't just be merged, which is why the docs build check is marked as required.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so the lesson for me is: if I change something in docs/posts/, I should re-run it locally before pushing. For this BigQuery notebook, I won't be able to do this since I don't have credentials, but for the rest I should be good to go.

(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
),
team_plan=ibis.ifelse(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
Expand Down
24 changes: 11 additions & 13 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -489,18 +489,16 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
cases = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
expr = t.mutate(cond_value=cases)
ibis.to_sql(expr)
```

Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, assert_sql):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

assert_sql(expr)
assert len(con.execute(expr))


def test_search_case(con, alltypes, assert_sql):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

assert_sql(expr)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
38 changes: 12 additions & 26 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis import udf

Expand Down Expand Up @@ -122,36 +121,23 @@ def predict_price(
df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"]
return model.predict(df)

def cases(value, mapping):
"""This should really be a top-level function or method."""
expr = ibis.case()
for k, v in mapping.items():
expr = expr.when(value == k, v)
return expr.end()

diamonds = con.tables.DIAMONDS
expr = diamonds.mutate(
predicted_price=predict_price(
(_.carat - _.carat.mean()) / _.carat.std(),
cases(
_.cut,
{
c: i
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
},
_.cut.cases(
(c, i)
for i, c in enumerate(
("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1
)
),
cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}),
cases(
_.clarity,
{
c: i
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
},
_.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)),
_.clarity.cases(
(c, i)
for i, c in enumerate(
("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"),
start=1,
)
),
)
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def difference(con):
@pytest.fixture(scope="module")
def simple_case(con):
t = con.table("alltypes")
return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture(scope="module")
def search_case(con):
t = con.table("alltypes")
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
lit2 = ibis.literal("bar")

result = alltypes.select(
alltypes.g.case()
.when(lit, lit2)
.when(lit1, ibis.literal("qux"))
.else_(ibis.literal("default"))
.end()
.name("col1"),
ibis.case()
.when((alltypes.g == lit), lit2)
.when((alltypes.g == lit1), alltypes.g)
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.g.cases(
(lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default")
).name("col1"),
ibis.cases(
((alltypes.g == lit), lit2),
((alltypes.g == lit1), alltypes.g),
else_=ibis.literal(None),
).name("col2"),
alltypes.a,
alltypes.b,
alltypes.c,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def test_bool_bool(snapshot):

def test_case_in_projection(alltypes, snapshot):
t = alltypes
expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end()
expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default"))
expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g))
expr = t.select(expr.name("col1"), expr2.name("col2"), t)

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_first_last(alltypes, method, filtered, include_null):
# To sanely test this we create a column that is a mix of nulls and a
# single value (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -738,7 +738,7 @@ def test_arbitrary(alltypes, filtered):
# _something_ we create a column that is a mix of nulls and a single value
# (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -1571,9 +1571,7 @@ def collect_udf(v):

def test_binds_are_cast(alltypes):
expr = alltypes.aggregate(
high_line_count=(
alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum()
)
high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum()
)

expr.execute()
Expand Down Expand Up @@ -1616,7 +1614,7 @@ def test_agg_name_in_output_column(alltypes):
def test_grouped_case(backend, con):
table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end()
case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null())

expr = (
table.group_by(k="key")
Expand Down
75 changes: 51 additions & 24 deletions ibis/backends/tests/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import Counter

import pytest
from pytest import param

import ibis

Expand Down Expand Up @@ -62,18 +63,13 @@
@pytest.mark.parametrize(
"inp, exp",
[
pytest.param(
lambda: ibis.literal(1)
.case()
.when(1, "one")
.when(2, "two")
.else_("other")
.end(),
param(
lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"),
"one",
id="one_kwarg",
),
pytest.param(
lambda: ibis.literal(5).case().when(1, "one").when(2, "two").end(),
param(
lambda: ibis.literal(5).cases((1, "one"), (2, "two")),
None,
id="fallthrough",
),
Expand All @@ -94,13 +90,8 @@
np = pytest.importorskip("numpy")

df = batting.to_pandas()
expr = (
batting.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
expr = batting.RBI.cases(
(5, "five"), (4, "four"), (3, "three"), else_="could be good?"
)
result = expr.execute()
expected = np.select(
Expand All @@ -113,7 +104,7 @@


def test_ibis_cases_scalar():
expr = ibis.literal(5).case().when(5, "five").when(4, "four").end()
expr = ibis.literal(5).cases((5, "five"), (4, "four"))

Check warning on line 107 in ibis/backends/tests/test_conditionals.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/tests/test_conditionals.py#L107

Added line #L107 was not covered by tests
result = expr.execute()
assert result == "five"

Expand All @@ -128,12 +119,8 @@

t = batting
df = batting.to_pandas()
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID
)
result = expr.execute()
expected = np.select(
Expand All @@ -148,5 +135,45 @@
@pytest.mark.notimpl("clickhouse", reason="special case this and returns 'oops'")
def test_value_cases_null(con):
"""CASE x WHEN NULL never gets hit"""
e = ibis.literal(5).nullif(5).case().when(None, "oops").else_("expected").end()
e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected")
assert con.execute(e) == "expected"


@pytest.mark.parametrize(
("example", "expected"),
[
param(lambda: ibis.case().when(True, "yes").end(), "yes", id="top-level-true"),
param(lambda: ibis.case().when(False, "yes").end(), None, id="top-level-false"),
param(
lambda: ibis.case().when(False, "yes").else_("no").end(),
"no",
id="top-level-false-value",
),
param(
lambda: ibis.literal("a").case().when("a", "yes").end(),
"yes",
id="method-true",
),
param(
lambda: ibis.literal("a").case().when("b", "yes").end(),
None,
id="method-false",
),
param(
lambda: ibis.literal("a").case().when("b", "yes").else_("no").end(),
"no",
id="method-false-value",
),
],
)
def test_ibis_case_still_works(con, example, expected):
# test that the soft-deprecated .case() method still works
# https://github.com/ibis-project/ibis/pull/9096
pd = pytest.importorskip("pandas")

with pytest.warns(FutureWarning):
expr = example()

result = con.execute(expr)

assert (expected is None and pd.isna(result)) or result == expected
Loading