From e0fa647ebed6e2fffbd549a68834f2dc8727330c Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 12 Jul 2024 09:08:51 -0800 Subject: [PATCH] feat: move from .case() to .cases() Fixes https://github.com/ibis-project/ibis/issues/7280 --- docs/posts/ci-analysis/index.qmd | 14 +- docs/tutorials/ibis-for-sql-users.qmd | 24 +- .../clickhouse/tests/test_operators.py | 14 +- ibis/backends/dask/tests/test_operations.py | 58 ---- ibis/backends/impala/tests/test_case_exprs.py | 4 +- ibis/backends/pandas/executor.py | 2 + ibis/backends/pandas/tests/test_operations.py | 66 +---- ibis/backends/snowflake/tests/test_udf.py | 38 +-- ibis/backends/tests/sql/conftest.py | 4 +- .../test_case_in_projection/decompiled.py | 20 +- .../test_case_in_projection/out.sql | 2 +- .../test_sql/test_searched_case/out.sql | 2 +- ibis/backends/tests/sql/test_select_sql.py | 4 +- ibis/backends/tests/test_aggregation.py | 10 +- ibis/backends/tests/test_generic.py | 113 +++++++- ibis/backends/tests/test_sql.py | 20 +- ibis/backends/tests/test_string.py | 18 +- ibis/backends/tests/test_struct.py | 2 +- ibis/backends/tests/tpc/h/test_queries.py | 24 +- ibis/expr/api.py | 88 +++--- ibis/expr/decompile.py | 16 +- ibis/expr/operations/generic.py | 25 +- ibis/expr/operations/logical.py | 2 +- ibis/expr/types/generic.py | 260 +++++++++--------- ibis/expr/types/numeric.py | 9 +- ibis/expr/types/relations.py | 4 +- ibis/tests/expr/test_case.py | 181 +++++++----- ibis/tests/expr/test_value_exprs.py | 16 +- 28 files changed, 508 insertions(+), 532 deletions(-) diff --git a/docs/posts/ci-analysis/index.qmd b/docs/posts/ci-analysis/index.qmd index 5babc2c6d0c61..65159d5bbe9c7 100644 --- a/docs/posts/ci-analysis/index.qmd +++ b/docs/posts/ci-analysis/index.qmd @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots. stats = stats.mutate( raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int") ).mutate( - improvements=( - _.raw_improvements.case() - .when(0, "None") - .when(1, "Poetry") - .when(2, "Poetry + Team Plan") - .else_("NA") - .end() - ), + improvements=_.raw_improvements.cases( + (0, "None"), + (1, "Poetry"), + (2, "Poetry + Team Plan"), + else_="NA", + ) team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"), ) stats diff --git a/docs/tutorials/ibis-for-sql-users.qmd b/docs/tutorials/ibis-for-sql-users.qmd index 534090bfce649..6d8c9a556b407 100644 --- a/docs/tutorials/ibis-for-sql-users.qmd +++ b/docs/tutorials/ibis-for-sql-users.qmd @@ -473,11 +473,11 @@ semantics: case = ( t.one.cast("timestamp") .year() - .case() - .when(2015, "This year") - .when(2014, "Last year") - .else_("Earlier") - .end() + .cases( + (2015, "This year"), + (2014, "Last year"), + else_="Earlier", + ) ) expr = t.mutate(year_group=case) @@ -496,18 +496,16 @@ CASE END ``` -To do this, use `ibis.case`: +To do this, use `ibis.cases`: ```{python} -case = ( - ibis.case() - .when(t.two < 0, t.three * 2) - .when(t.two > 1, t.three) - .else_(t.two) - .end() +cases = ibis.cases( + (t.two < 0, t.three * 2), + (t.two > 1, t.three), + else_=t.two, ) -expr = t.mutate(cond_value=case) +expr = t.mutate(cond_value=cases) ibis.to_sql(expr) ``` diff --git a/ibis/backends/clickhouse/tests/test_operators.py b/ibis/backends/clickhouse/tests/test_operators.py index 4ca53a3d2b9f3..3ff07ce916a45 100644 --- a/ibis/backends/clickhouse/tests/test_operators.py +++ b/ibis/backends/clickhouse/tests/test_operators.py @@ -201,9 +201,7 @@ def test_ifelse(alltypes, df, op, pandas_op): def test_simple_case(con, alltypes, assert_sql): t = alltypes - expr = ( - t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end() - ) + expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default") assert_sql(expr) assert len(con.execute(expr)) @@ -211,12 +209,10 @@ def test_simple_case(con, alltypes, assert_sql): def test_search_case(con, alltypes, assert_sql): t = alltypes - expr = ( - ibis.case() - .when(t.float_col > 0, t.int_col * 2) - .when(t.float_col < 0, t.int_col) - .else_(0) - .end() + expr = ibis.cases( + (t.float_col > 0, t.int_col * 2), + (t.float_col < 0, t.int_col), + else_=0, ) assert_sql(expr) diff --git a/ibis/backends/dask/tests/test_operations.py b/ibis/backends/dask/tests/test_operations.py index e43e5af454933..cf6bd9a9eb040 100644 --- a/ibis/backends/dask/tests/test_operations.py +++ b/ibis/backends/dask/tests/test_operations.py @@ -773,64 +773,6 @@ def q_fun(x, quantile): tm.assert_series_equal(result, expected, check_index=False) -def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_searched_case_column(batting, batting_pandas_df): - t = batting - df = batting_pandas_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI < 5, df.teamID == "PH1"], - ["really bad team", "ph1 team"], - df.teamID, - ) - ) - tm.assert_series_equal(result, expected, check_names=False) - - -def test_simple_case_scalar(client): - x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_simple_case_column(batting, batting_pandas_df): - t = batting - df = batting_pandas_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI == 5, df.RBI == 4, df.RBI == 3], - ["five", "four", "three"], - "could be good?", - ) - ) - tm.assert_series_equal(result, expected, check_names=False) - - def test_table_distinct(t, df): expr = t[["dup_strings"]].distinct() result = expr.compile() diff --git a/ibis/backends/impala/tests/test_case_exprs.py b/ibis/backends/impala/tests/test_case_exprs.py index a195928b12214..360fbf9522c8b 100644 --- a/ibis/backends/impala/tests/test_case_exprs.py +++ b/ibis/backends/impala/tests/test_case_exprs.py @@ -14,13 +14,13 @@ def table(mockcon): @pytest.fixture def simple_case(table): - return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture def search_case(table): t = table - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index a3153d17b8b47..2563d4be77a5d 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -167,6 +167,8 @@ def visit(cls, op: ops.IsNan, arg): def visit( cls, op: ops.SearchedCase | ops.SimpleCase, cases, results, default, base=None ): + if not cases: + return default if base is not None: cases = tuple(base == case for case in cases) cases, _ = cls.asframe(cases, concat=False) diff --git a/ibis/backends/pandas/tests/test_operations.py b/ibis/backends/pandas/tests/test_operations.py index 293fc008a50e9..8fa1338132083 100644 --- a/ibis/backends/pandas/tests/test_operations.py +++ b/ibis/backends/pandas/tests/test_operations.py @@ -683,73 +683,9 @@ def test_summary_non_numeric(batting, batting_df): assert dict(result.iloc[0]) == expected -def test_searched_case_scalar(client): - expr = ibis.case().when(True, 1).when(False, 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_searched_case_column(batting, batting_df): - t = batting - df = batting_df - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI < 5, df.teamID == "PH1"], - ["really bad team", "ph1 team"], - df.teamID, - ) - ) - tm.assert_series_equal(result, expected) - - -def test_simple_case_scalar(client): - x = ibis.literal(2) - expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end() - result = client.execute(expr) - expected = np.int8(1) - assert result == expected - - -def test_simple_case_column(batting, batting_df): - t = batting - df = batting_df - expr = ( - t.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() - ) - result = expr.execute() - expected = pd.Series( - np.select( - [df.RBI == 5, df.RBI == 4, df.RBI == 3], - ["five", "four", "three"], - "could be good?", - ) - ) - tm.assert_series_equal(result, expected) - - def test_non_range_index(): def do_replace(col): - return col.cases( - ( - (1, "one"), - (2, "two"), - ), - default="unk", - ) + return col.cases((1, "one"), (2, "two"), else_="unk") df = pd.DataFrame( { diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 4a59013cebece..2ee68897f4419 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -8,7 +8,6 @@ import pytest from pytest import param -import ibis import ibis.expr.datatypes as dt from ibis import udf @@ -122,36 +121,23 @@ def predict_price( df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"] return model.predict(df) - def cases(value, mapping): - """This should really be a top-level function or method.""" - expr = ibis.case() - for k, v in mapping.items(): - expr = expr.when(value == k, v) - return expr.end() - diamonds = con.tables.DIAMONDS expr = diamonds.mutate( predicted_price=predict_price( (_.carat - _.carat.mean()) / _.carat.std(), - cases( - _.cut, - { - c: i - for i, c in enumerate( - ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 - ) - }, + _.cut.cases( + (c, i) + for i, c in enumerate( + ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 + ) ), - cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}), - cases( - _.clarity, - { - c: i - for i, c in enumerate( - ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), - start=1, - ) - }, + _.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)), + _.clarity.cases( + (c, i) + for i, c in enumerate( + ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), + start=1, + ) ), ) ) diff --git a/ibis/backends/tests/sql/conftest.py b/ibis/backends/tests/sql/conftest.py index 04667e60e033b..06de1c83c8c08 100644 --- a/ibis/backends/tests/sql/conftest.py +++ b/ibis/backends/tests/sql/conftest.py @@ -164,13 +164,13 @@ def difference(con): @pytest.fixture(scope="module") def simple_case(con): t = con.table("alltypes") - return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture(scope="module") def search_case(con): t = con.table("alltypes") - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture(scope="module") diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py index 6058efaa962e6..35fb932c2248f 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py @@ -22,18 +22,14 @@ lit2 = ibis.literal("bar") result = alltypes.select( - alltypes.g.case() - .when(lit, lit2) - .when(lit1, ibis.literal("qux")) - .else_(ibis.literal("default")) - .end() - .name("col1"), - ibis.case() - .when(alltypes.g == lit, lit2) - .when(alltypes.g == lit1, alltypes.g) - .else_(ibis.literal(None).cast("string")) - .end() - .name("col2"), + alltypes.g.cases( + (lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default") + ).name("col1"), + ibis.cases( + (alltypes.g == lit, lit2), + (alltypes.g == lit1, alltypes.g), + else_=ibis.literal(None), + ).name("col2"), alltypes.a, alltypes.b, alltypes.c, diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/out.sql index 3f14a3c53b882..2a412787b51ea 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/out.sql @@ -5,7 +5,7 @@ SELECT THEN 'bar' WHEN "t0"."g" = 'baz' THEN "t0"."g" - ELSE CAST(NULL AS TEXT) + ELSE NULL END AS "col2", "t0"."a", "t0"."b", diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_searched_case/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_searched_case/out.sql index 1bbe6d29ebd73..a0fd1283a2ed2 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_searched_case/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_searched_case/out.sql @@ -4,6 +4,6 @@ SELECT THEN "t0"."d" * CAST(2 AS TINYINT) WHEN "t0"."c" < CAST(0 AS TINYINT) THEN "t0"."a" * CAST(2 AS TINYINT) - ELSE CAST(NULL AS BIGINT) + ELSE NULL END AS "tmp" FROM "alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 94a52017f763f..24893739fb6eb 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -397,8 +397,8 @@ def test_bool_bool(snapshot): def test_case_in_projection(alltypes, snapshot): t = alltypes - expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() - expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end() + expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default")) + expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g)) expr = t[expr.name("col1"), expr2.name("col2"), t] snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 5b9627181514f..1456424213d1d 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -649,7 +649,7 @@ def test_first_last(backend, alltypes, method, filtered): # To sanely test this we create a column that is a mix of nulls and a # single value (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -681,7 +681,7 @@ def test_arbitrary(backend, alltypes, df, filtered): # _something_ we create a column that is a mix of nulls and a single value # (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -1434,9 +1434,7 @@ def collect_udf(v): def test_binds_are_cast(alltypes): expr = alltypes.aggregate( - high_line_count=( - alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum() - ) + high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum() ) expr.execute() @@ -1482,7 +1480,7 @@ def test_agg_name_in_output_column(alltypes): def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) - case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end() + case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null()) expr = ( table.group_by(k="key") diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ecdc7f83594c6..a9a5c4e389322 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -389,12 +389,11 @@ def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( new_col=( - ibis.case() - .when(table["int_col"] == 1, 20) - .when(table["int_col"] == 0, 10) - .else_(0) - .end() - .cast("int64") + ibis.cases( + (table["int_col"] == 1, 20), + (table["int_col"] == 0, 10), + else_=0, + ).cast("int64") ) ) @@ -427,9 +426,7 @@ def test_select_filter_mutate(backend, alltypes, df): # Prepare the float_col so that filter must execute # before the cast to get the correct result. - t = t.mutate( - float_col=ibis.case().when(t["bool_col"], t["float_col"]).else_(np.nan).end() - ) + t = t.mutate(float_col=ibis.cases((t["bool_col"], t["float_col"]), else_=np.nan)) # Actual test t = t[t.columns] @@ -2307,7 +2304,32 @@ def test_sample_with_seed(backend): ), ], ) -def test_value_cases(con, inp, exp): +def test_value_cases_deprecated(con, inp, exp): + with pytest.warns(FutureWarning): + i = inp() + result = con.execute(i) + if exp is None: + assert pd.isna(result) + else: + assert result == exp + + +@pytest.mark.parametrize( + "inp, exp", + [ + pytest.param( + lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"), + "one", + id="one_kwarg", + ), + pytest.param( + lambda: ibis.literal(5).cases((1, "one"), (2, "two")), + None, + id="fallthrough", + ), + ], +) +def test_value_cases_scalar(con, inp, exp): result = con.execute(inp()) if exp is None: assert pd.isna(result) @@ -2315,6 +2337,49 @@ def test_value_cases(con, inp, exp): assert result == exp +@pytest.mark.broken( + "exasol", + reason="the int64 RBI column is .to_pandas()ed to an object column, which is incomparable to ints", + raises=AssertionError, +) +def test_value_cases_column(batting): + df = batting.to_pandas() + expr = batting.RBI.cases( + (5, "five"), (4, "four"), (3, "three"), else_="could be good?" + ) + result = expr.execute() + expected = np.select( + [df.RBI == 5, df.RBI == 4, df.RBI == 3], + ["five", "four", "three"], + "could be good?", + ) + + assert Counter(result) == Counter(expected) + + +@pytest.mark.broken( + ["sqlite", "exasol"], + reason="the int64 RBI column is .to_pandas()ed to an object column, which is incomparable to 5", + raises=TypeError, +) +def test_ibis_cases_column(batting): + t = batting + df = batting.to_pandas() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), + (t.teamID == "PH1", "ph1 team"), + else_=t.teamID, + ) + result = expr.execute() + expected = np.select( + [df.RBI < 5, df.teamID == "PH1"], + ["really bad team", "ph1 team"], + df.teamID, + ) + + assert Counter(result) == Counter(expected) + + def test_substitute(backend): val = "400" t = backend.functional_alltypes @@ -2328,6 +2393,34 @@ def test_substitute(backend): assert expr["subs_count"].execute()[0] == t.count().execute() // 10 +@pytest.mark.broken("clickhouse", reason="special case this and returns 'oops'") +def test_value_cases_null(con): + """CASE x WHEN NULL never gets hit""" + e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected") + assert con.execute(e) == "expected" + + +@pytest.mark.broken("pyspark", reason="raises a ResourceWarning that we can't catch") +def test_case(con): + # just to make sure that the deprecated .case() method still works + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.case().when(True, "yes").end()) == "yes" + with pytest.warns(FutureWarning, match=".cases"): + assert pd.isna(con.execute(ibis.case().when(False, "yes").end())) + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.case().when(False, "yes").else_("no").end()) == "no" + + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.literal("a").case().when("a", "yes").end()) == "yes" + with pytest.warns(FutureWarning, match=".cases"): + assert pd.isna(con.execute(ibis.literal("a").case().when("b", "yes").end())) + with pytest.warns(FutureWarning, match=".cases"): + assert ( + con.execute(ibis.literal("a").case().when("b", "yes").else_("no").end()) + == "no" + ) + + @pytest.mark.notimpl( ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 777cfa3db8bb3..c0bfb53f18170 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -59,16 +59,16 @@ def test_group_by_has_index(backend, snapshot): ) expr = countries.group_by( cont=( - _.continent.case() - .when("NA", "North America") - .when("SA", "South America") - .when("EU", "Europe") - .when("AF", "Africa") - .when("AS", "Asia") - .when("OC", "Oceania") - .when("AN", "Antarctica") - .else_("Unknown continent") - .end() + _.continent.cases( + ("NA", "North America"), + ("SA", "South America"), + ("EU", "Europe"), + ("AF", "Africa"), + ("AS", "Asia"), + ("OC", "Oceania"), + ("AN", "Antarctica"), + else_="Unknown continent", + ) ) ).agg(total_pop=_.population.sum()) sql = str(ibis.to_sql(expr, dialect=backend.name())) diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index ceb9fdc77711b..e2352055ea9e7 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -507,9 +507,9 @@ def uses_java_re(t): id="length", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").startswith( - "abc" - ), + lambda t: t.int_col.cases( + (1, "abcd"), (2, "ABCD"), else_="dabc" + ).startswith("abc"), lambda t: t.int_col == 1, id="startswith", marks=[ @@ -517,7 +517,7 @@ def uses_java_re(t): ], ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").endswith( + lambda t: t.int_col.cases((1, "abcd"), (2, "ABCD"), else_="dabc").endswith( "bcd" ), lambda t: t.int_col == 1, @@ -693,11 +693,9 @@ def test_re_replace_global(con): @pytest.mark.notimpl(["druid"], raises=ValidationError) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( - substr_col_null=ibis.case() - .when(alltypes["bool_col"], alltypes["string_col"]) - .else_(None) - .end() - .substr(0, 2) + substr_col_null=ibis.cases( + (alltypes["bool_col"], alltypes["string_col"]), else_=None + ).substr(0, 2) ) result = table.execute() @@ -910,7 +908,7 @@ def test_levenshtein(con, right): @pytest.mark.parametrize( "expr", [ - param(ibis.case().when(True, "%").end(), id="case"), + param(ibis.cases((True, "%")), id="case"), param(ibis.ifelse(True, "%", ibis.null()), id="ifelse"), ], ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 6a7429a6c2ffa..c29582d02d2e8 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -148,7 +148,7 @@ def test_collect_into_struct(alltypes): @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) - x = ibis.case().when(True, s).else_(ibis.struct({"a": 4})).end() + x = ibis.cases((True, s), else_=ibis.struct({"a": 4})) y = x.a assert con.to_pandas(y) == 3 diff --git a/ibis/backends/tests/tpc/h/test_queries.py b/ibis/backends/tests/tpc/h/test_queries.py index 208969eec7e41..bb0fd74ece777 100644 --- a/ibis/backends/tests/tpc/h/test_queries.py +++ b/ibis/backends/tests/tpc/h/test_queries.py @@ -261,9 +261,7 @@ def test_08(part, supplier, region, lineitem, orders, customer, nation): ] ) - q = q.mutate( - nation_volume=ibis.case().when(q.nation == NATION, q.volume).else_(0).end() - ) + q = q.mutate(nation_volume=ibis.cases((q.nation == NATION, q.volume), else_=0)) gq = q.group_by([q.o_year]) q = gq.aggregate(mkt_share=q.nation_volume.sum() / q.volume.sum()) q = q.order_by([q.o_year]) @@ -389,19 +387,15 @@ def test_12(orders, lineitem): gq = q.group_by([q.l_shipmode]) q = gq.aggregate( - high_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 1) - .when("2-HIGH", 1) - .else_(0) - .end() + high_line_count=q.o_orderpriority.cases( + ("1-URGENT", 1), + ("2-HIGH", 1), + else_=0, ).sum(), - low_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 0) - .when("2-HIGH", 0) - .else_(1) - .end() + low_line_count=q.o_orderpriority.cases( + ("1-URGENT", 0), + ("2-HIGH", 0), + else_=1, ).sum(), ) q = q.order_by(q.l_shipmode) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 92e14066cac8f..093b6bb0d7044 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -67,6 +67,7 @@ "array", "asc", "case", + "cases", "coalesce", "connect", "cross_join", @@ -1108,56 +1109,71 @@ def interval( return functools.reduce(operator.add, intervals) +@util.deprecated(instead="use ibis.cases() instead", as_of="9.1") def case() -> bl.SearchedCaseBuilder: - """Begin constructing a case expression. + """DEPRECATED: Use `ibis.cases()` instead.""" + return bl.SearchedCaseBuilder() + + +@deferrable +def cases(*branches: tuple[Any, Any], else_: Any | None = None) -> ir.Value: + """Create a multi-branch if-else expression. - Use the `.when` method on the resulting object followed by `.end` to create a - complete case expression. + Goes through each (condition, value) pair in `branches`, finding the + first condition that evaluates to True, and returns the corresponding + value. If no condition is True, returns `else_`. Returns ------- - SearchedCaseBuilder - A builder object to use for constructing a case expression. + Value + A value expression See Also -------- - [`Value.case()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) Examples -------- >>> import ibis - >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable( - ... { - ... "left": [1, 2, 3, 4], - ... "symbol": ["+", "-", "*", "/"], - ... "right": [5, 6, 7, 8], - ... } - ... ) - >>> t.mutate( - ... result=( - ... ibis.case() - ... .when(_.symbol == "+", _.left + _.right) - ... .when(_.symbol == "-", _.left - _.right) - ... .when(_.symbol == "*", _.left * _.right) - ... .when(_.symbol == "/", _.left / _.right) - ... .end() - ... ) - ... ) - ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ - ┃ left ┃ symbol ┃ right ┃ result ┃ - ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ - │ int64 │ string │ int64 │ float64 │ - ├───────┼────────┼───────┼─────────┤ - │ 1 │ + │ 5 │ 6.0 │ - │ 2 │ - │ 6 │ -4.0 │ - │ 3 │ * │ 7 │ 21.0 │ - │ 4 │ / │ 8 │ 0.5 │ - └───────┴────────┴───────┴─────────┘ - + >>> v = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}).values + >>> ibis.cases((v == 1, "a"), (v > 2, "b"), else_="unk").name("cases") + ┏━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ unk │ + │ a │ + │ unk │ + │ b │ + │ unk │ + │ b │ + └────────┘ + >>> ibis.cases( + ... (v % 2 == 0, "divisible by 2"), + ... (v % 3 == 0, "divisible by 3"), + ... (v % 4 == 0, "shadowed by the 2 case"), + ... ).name("cases") + ┏━━━━━━━━━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ NULL │ + │ divisible by 2 │ + │ NULL │ + │ divisible by 2 │ + │ divisible by 3 │ + │ divisible by 2 │ + │ divisible by 2 │ + └────────────────┘ """ - return bl.SearchedCaseBuilder() + if not branches: + raise ValueError("At least one branch is required") + cases, results = zip(*branches) + return ops.SearchedCase(cases=cases, results=results, default=else_).to_expr() def now() -> ir.TimestampScalar: diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 7d87550a9bcf6..2d73d35b2570c 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -304,16 +304,12 @@ def ifelse(op, bool_expr, true_expr, false_null_expr): @translate.register(ops.SimpleCase) @translate.register(ops.SearchedCase) -def switch_case(op, cases, results, default, base=None): - out = f"{base}.case()" if base else "ibis.case()" - - for case, result in zip(cases, results): - out = f"{out}.when({case}, {result})" - - if default is not None: - out = f"{out}.else_({default})" - - return f"{out}.end()" +def switch_cases(op, cases, results, default, base=None): + namespace = f"{base}" if base else "ibis" + case_strs = [f"({case}, {result})" for case, result in zip(cases, results)] + cases_str = ", ".join(case_strs) + else_str = f", else_={default}" if default is not None else "" + return f"{namespace}.cases({cases_str}{else_str})" _infix_ops = { diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index cfa3ece1b456f..4012216f0aa71 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -293,11 +293,22 @@ class SimpleCase(Value): results: VarTuple[Value] default: Value - shape = rlz.shape_like("base") - - def __init__(self, cases, results, **kwargs): + def __init__(self, base, cases, results, default): assert len(cases) == len(results) - super().__init__(cases=cases, results=results, **kwargs) + + for case in cases: + if not rlz.comparable(base, case): + raise TypeError( + f"Base expression {rlz._arg_type_error_format(base)} and " + f"case {rlz._arg_type_error_format(case)} are not comparable" + ) + + super().__init__(base=base, cases=cases, results=results, default=default) + + @attribute + def shape(self): + exprs = [self.base, *self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): @@ -315,14 +326,12 @@ class SearchedCase(Value): def __init__(self, cases, results, default): assert len(cases) == len(results) - if default.dtype.is_null(): - default = Cast(default, rlz.highest_precedence_dtype(results)) super().__init__(cases=cases, results=results, default=default) @attribute def shape(self): - # TODO(kszucs): can be removed after making Sequence iterable - return rlz.highest_precedence_shape(self.cases) + exprs = [*self.cases, *self.results, self.default] + return rlz.highest_precedence_shape(exprs) @attribute def dtype(self): diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index 22235ca811245..fb5192570baad 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -154,7 +154,7 @@ class IfElse(Value): Equivalent to ```python - bool_expr.case().when(True, true_expr).else_(false_or_null_expr) + bool_expr.cases((True, true_expr), else_=false_or_null_expr) ``` Many backends implement this as a built-in function. diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 9c7fb9e6fe996..3da85a4e66b52 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +import warnings +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from public import public @@ -10,6 +11,7 @@ import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis import util from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input @@ -723,19 +725,18 @@ def substitute( └────────┴──────────────┘ """ if isinstance(value, dict): - expr = ibis.case() - try: - null_replacement = value.pop(None) - except KeyError: - pass - else: - expr = expr.when(self.isnull(), null_replacement) - for k, v in value.items(): - expr = expr.when(self == k, v) + branches = list(value.items()) else: - expr = self.case().when(value, replacement) - - return expr.else_(else_ if else_ is not None else self).end() + branches = [(value, replacement)] + nulls = [(k, v) for k, v in branches if k is None] + nonnulls = [(k, v) for k, v in branches if k is not None] + if nulls: + null_replacement = nulls[0][1] + self = self.fill_null(null_replacement) + else_ = else_ if else_ is not None else self + if not nonnulls: + return else_ + return self.cases(*nonnulls, else_=else_) def over( self, @@ -871,99 +872,80 @@ def notnull(self) -> ir.BooleanValue: """ return ops.NotNull(self).to_expr() + @util.deprecated(instead="use Value.cases() instead", as_of="9.1") def case(self) -> bl.SimpleCaseBuilder: - """Create a SimpleCaseBuilder to chain multiple if-else statements. - - Add new search expressions with the `.when()` method. These must be - comparable with this column expression. Conclude by calling `.end()`. - - Returns - ------- - SimpleCaseBuilder - A case builder - - See Also - -------- - [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + """DEPRECATED: Use `self.cases()` instead.""" + return bl.SimpleCaseBuilder(self.op()) - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> x = ibis.examples.penguins.fetch().head(5)["sex"] - >>> x - ┏━━━━━━━━┓ - ┃ sex ┃ - ┡━━━━━━━━┩ - │ string │ - ├────────┤ - │ male │ - │ female │ - │ female │ - │ NULL │ - │ female │ - └────────┘ - >>> x.case().when("male", "M").when("female", "F").else_("U").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male', 'female'), ('M', 'F'), 'U') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────────────────────────────────────┤ - │ M │ - │ F │ - │ F │ - │ U │ - │ F │ - └──────────────────────────────────────────────────────┘ - - Cases not given result in the ELSE case - - >>> x.case().when("male", "M").else_("OTHER").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), 'OTHER') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├─────────────────────────────────────────────┤ - │ M │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - └─────────────────────────────────────────────┘ - - If you don't supply an ELSE, then NULL is used - - >>> x.case().when("male", "M").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), Cast(None, string)) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├────────────────────────────────────────────────────────┤ - │ M │ - │ NULL │ - │ NULL │ - │ NULL │ - │ NULL │ - └────────────────────────────────────────────────────────┘ - """ - import ibis.expr.builders as bl + @staticmethod + def _norm_cases_args(*args, **kwargs): + # TODO: remove in v10.0 once we have a deprecation cycle + # before, the API for Value.cases() was + # def cases( + # self, + # case_result_pairs: Iterable[tuple[Value, Value]], + # default: Value | None = None, + # ) -> Value: + # Now it is + # def cases( + # self, + # *branches: tuple[Value, Value], + # else_: Value | None = None, + # ) -> Value: + # This method normalizes the arguments to the new API. + using_old_api = False + branches = [] + else_ = None + if len(args) >= 1: + first_arg = args[0] + first_arg = util.promote_list(first_arg) + if len(first_arg) > 0 and isinstance(first_arg[0], tuple): + # called as .cases([(test, result), ...], ) + using_old_api = True + branches = first_arg + else_ = args[1] if len(args) == 2 else None + else: + # called as .cases((test, result), ...) + branches = list(args) + + if "case_result_pairs" in kwargs: + using_old_api = True + branches = list(kwargs["case_result_pairs"]) + elif "branches" in kwargs: + branches = list(kwargs["branches"]) + + if "default" in kwargs: + using_old_api = True + else_ = kwargs["default"] + elif "else_" in kwargs: + else_ = kwargs["else_"] + + if using_old_api: + warnings.warn( + "You are using the old API for `cases()`. Please see" + " https://ibis-project.org/reference/expression-generic" + " on how to upgrade to the new API.", + FutureWarning, + ) + return branches, else_ - return bl.SimpleCaseBuilder(self.op()) + def cases(self, *args, **kwargs) -> Value: # noqa: D417 + """Create a multi-branch if-else expression. - def cases( - self, - case_result_pairs: Iterable[tuple[ir.BooleanValue, Value]], - default: Value | None = None, - ) -> Value: - """Create a case expression in one shot. + This is semantically equivalent to + CASE self + WHEN test_val0 THEN result0 + WHEN test_val1 THEN result1 + ELSE else_ + END Parameters ---------- - case_result_pairs - Conditional-result pairs - default + branches + (test_val, result) pairs. We look through the test values in order + and return the result corresponding to the first test value that + matches `self`. If none match, we return `else_`. + else_ Value to return if none of the case conditions are true Returns @@ -974,48 +956,56 @@ def cases( See Also -------- [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) Examples -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}) - >>> t - ┏━━━━━━━━┓ - ┃ values ┃ - ┡━━━━━━━━┩ - │ int64 │ - ├────────┤ - │ 1 │ - │ 2 │ - │ 1 │ - │ 2 │ - │ 3 │ - │ 2 │ - │ 4 │ - └────────┘ - >>> number_letter_map = ((1, "a"), (2, "b"), (3, "c")) - >>> t.values.cases(number_letter_map, default="unk").name("replace") - ┏━━━━━━━━━┓ - ┃ replace ┃ - ┡━━━━━━━━━┩ - │ string │ - ├─────────┤ - │ a │ - │ b │ - │ a │ - │ b │ - │ c │ - │ b │ - │ unk │ - └─────────┘ + >>> t = ibis.memtable( + ... { + ... "left": [5, 6, 7, 8, 9, 10], + ... "symbol": ["+", "-", "*", "/", "bogus", None], + ... "right": [1, 2, 3, 4, 5, 6], + ... } + ... ) + + Note we never hit the `None` case, because `x = NULL` is always NULL, + which is not truthy. If you want to replace NULLs, you should use + `.fillna(-999)` prior to `cases()`. + + >>> t.mutate( + ... result=( + ... t.symbol.cases( + ... ("+", t.left + t.right), + ... ("-", t.left - t.right), + ... ("*", t.left * t.right), + ... ("/", t.left / t.right), + ... (None, -999), + ... ) + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 5 │ + │ 1 │ 6.0 │ + │ 6 │ - │ 2 │ 4.0 │ + │ 7 │ * │ 3 │ 21.0 │ + │ 8 │ / │ 4 │ 2.0 │ + │ 9 │ bogus │ 5 │ NULL │ + │ 10 │ NULL │ 6 │ NULL │ + └───────┴────────┴───────┴─────────┘ """ - builder = self.case() - for case, result in case_result_pairs: - builder = builder.when(case, result) - return builder.else_(default).end() + branches, else_ = self._norm_cases_args(*args, **kwargs) + + if not branches: + raise ValueError("At least one branch is required") + cases, results = zip(*branches) + return ops.SimpleCase( + base=self, cases=cases, results=results, default=else_ + ).to_expr() def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 36da3740a9444..820d731973bb4 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING, Literal from public import public @@ -1143,13 +1142,7 @@ def label(self, labels: Iterable[str], nulls: str | None = None) -> ir.StringVal │ 2 │ c │ └───────┴─────────┘ """ - return ( - functools.reduce( - lambda stmt, inputs: stmt.when(*inputs), enumerate(labels), self.case() - ) - .else_(nulls) - .end() - ) + return self.cases(*enumerate(labels), else_=nulls) @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index f577d4177d804..5772e427234ea 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2925,9 +2925,7 @@ def info(self) -> Table: for pos, colname in enumerate(self.columns): col = self[colname] typ = col.type() - agg = self.select( - isna=ibis.case().when(col.isnull(), 1).else_(0).end() - ).agg( + agg = self.select(isna=ibis.cases((col.isnull(), 1), else_=0)).agg( name=lit(colname), type=lit(str(typ)), nullable=lit(typ.nullable), diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index 89e5b3b7df7b4..7fa7006c64690 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -1,11 +1,14 @@ from __future__ import annotations +import pytest + import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import _ -from ibis.tests.util import assert_equal, assert_pickle_roundtrip +from ibis.common.annotations import SignatureValidationError +from ibis.tests.util import assert_pickle_roundtrip def test_ifelse_method(table): @@ -44,72 +47,63 @@ def test_ifelse_function_deferred(table): assert res.equals(sol) -def test_simple_case_expr(table): - case1, result1 = "foo", table.a - case2, result2 = "bar", table.c - default_result = table.b - - expr1 = table.g.lower().cases( - [(case1, result1), (case2, result2)], default=default_result - ) - - expr2 = ( - table.g.lower() - .case() - .when(case1, result1) - .when(case2, result2) - .else_(default_result) - .end() - ) +def test_err_on_bad_args(table): + with pytest.raises(ValueError): + ibis.cases((True,)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3), 5) - assert_equal(expr1, expr2) - assert isinstance(expr1, ir.IntegerColumn) + with pytest.raises(ValueError): + table.a.cases(("foo",)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(ValueError): + table.a.cases(("foo", 3, 4)) + with pytest.raises(TypeError): + table.a.cases(("foo", 3), 5) def test_multiple_case_expr(table): - expr = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(table.d) - .end() + expr = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=table.d, ) # deferred cases - deferred = ( - ibis.case() - .when(_.a == 5, table.f) - .when(_.b == 128, table.b * 2) - .when(_.c == 1000, table.e) - .else_(table.d) - .end() + deferred = ibis.cases( + (_.a == 5, table.f), + (_.b == 128, table.b * 2), + (_.c == 1000, table.e), + else_=table.d, ) expr2 = deferred.resolve(table) # deferred results - expr3 = ( - ibis.case() - .when(table.a == 5, _.f) - .when(table.b == 128, _.b * 2) - .when(table.c == 1000, _.e) - .else_(table.d) - .end() - .resolve(table) - ) + expr3 = ibis.cases( + (table.a == 5, _.f), + (table.b == 128, _.b * 2), + (table.c == 1000, _.e), + else_=table.d, + ).resolve(table) # deferred default - expr4 = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(_.d) - .end() - .resolve(table) + expr4 = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=_.d, + ).resolve(table) + + assert ( + repr(deferred) + == "cases(((_.a == 5), ), ((_.b == 128), ), ((_.c == 1000), ), else_=)" ) - - assert repr(deferred) == "" assert expr.equals(expr2) assert expr.equals(expr3) assert expr.equals(expr4) @@ -130,13 +124,11 @@ def test_pickle_multiple_case_node(table): result3 = table.e default = table.d - expr = ( - ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) - .end() + expr = ibis.cases( + (case1, result1), + (case2, result2), + (case3, result3), + else_=default, ) op = expr.op() @@ -144,7 +136,7 @@ def test_pickle_multiple_case_node(table): def test_simple_case_null_else(table): - expr = table.g.case().when("foo", "bar").end() + expr = table.g.cases(("foo", "bar")) op = expr.op() assert isinstance(expr, ir.StringColumn) @@ -154,8 +146,8 @@ def test_simple_case_null_else(table): def test_multiple_case_null_else(table): - expr = ibis.case().when(table.g == "foo", "bar").end() - expr2 = ibis.case().when(table.g == "foo", _).end().resolve("bar") + expr = ibis.cases((table.g == "foo", "bar")) + expr2 = ibis.cases((table.g == "foo", _)).resolve("bar") assert expr.equals(expr2) @@ -172,8 +164,65 @@ def test_case_mixed_type(): name="my_data", ) - expr = ( - t0.three.case().when(0, "low").when(1, "high").else_("null").end().name("label") - ) + expr = t0.three.cases((0, "low"), (1, "high"), else_="null").name("label") result = t0[expr] assert result["label"].type().equals(dt.string) + + +def test_err_on_nonbool(table): + with pytest.raises(SignatureValidationError): + ibis.cases((table.a, "bar"), else_="baz") + + +@pytest.mark.xfail(reason="Literal('foo', type=bool), should error, but doesn't") +def test_err_on_nonbool2(): + with pytest.raises(SignatureValidationError): + ibis.cases(("foo", "bar"), else_="baz") + + +def test_err_on_noncomparable(table): + table.a.cases((8, "bar")) + table.a.cases((-8, "bar")) + # Can't compare an int to a string + with pytest.raises(TypeError): + table.a.cases(("foo", "bar")) + + +def test_err_on_empty_cases(table): + with pytest.raises(ValueError): + ibis.cases() + with pytest.raises(ValueError): + ibis.cases(else_=42) + with pytest.raises(ValueError): + table.a.cases() + with pytest.raises(ValueError): + table.a.cases(else_=42) + + +def test_dtype(): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.StringValue) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.StringValue) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), (False, "bar")) + with pytest.raises(TypeError): + assert ibis.cases((True, 5), else_="bar") + + +def test_dshape(table): + assert isinstance(ibis.cases((True, "bar"), (False, "bar")), ir.Scalar) + assert isinstance(ibis.cases((True, None), else_="bar"), ir.Scalar) + assert isinstance(ibis.cases((table.b == 9, None), else_="bar"), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=42), ir.Column) + assert isinstance(ibis.cases((True, 42), else_=table.a), ir.Column) + assert isinstance(ibis.cases((True, table.a), else_=table.b), ir.Column) + + assert isinstance(ibis.literal(5).cases((9, 42)), ir.Scalar) + assert isinstance(ibis.literal(5).cases((9, 42), else_=43), ir.Scalar) + assert isinstance(ibis.literal(5).cases((table.a, 42)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, table.a)), ir.Column) + assert isinstance(ibis.literal(5).cases((table.a, table.b)), ir.Column) + assert isinstance(ibis.literal(5).cases((9, 42), else_=table.a), ir.Column) + assert isinstance(table.a.cases((9, 42)), ir.Column) + assert isinstance(table.a.cases((table.b, 42)), ir.Column) + assert isinstance(table.a.cases((9, table.b)), ir.Column) + assert isinstance(table.a.cases((table.a, table.b)), ir.Column) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index e7b57376052ad..26da30c326557 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -834,23 +834,11 @@ def test_substitute_dict(): subs = {"a": "one", "b": table.bar} result = table.foo.substitute(subs) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(table.foo) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=table.foo) assert_equal(result, expected) result = table.foo.substitute(subs, else_=ibis.null()) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(ibis.null()) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=ibis.null()) assert_equal(result, expected)