Skip to content

Commit

Permalink
fix(case): fix dshape, error on noncomparable and empty cases
Browse files Browse the repository at this point in the history
This is pinning down the expected behavior for cases before tackling
the case() to cases() switch in
ibis-project#9096
so that PR can be simpler
  • Loading branch information
NickCrews committed Jul 12, 2024
1 parent 8862979 commit 133d4cb
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ibis.case()
.when(alltypes.g == lit, lit2)
.when(alltypes.g == lit1, alltypes.g)
.else_(ibis.literal(None).cast("string"))
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SELECT
THEN 'bar'
WHEN "t0"."g" = 'baz'
THEN "t0"."g"
ELSE CAST(NULL AS TEXT)
ELSE NULL
END AS "col2",
"t0"."a",
"t0"."b",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SELECT
THEN "t0"."d" * CAST(2 AS TINYINT)
WHEN "t0"."c" < CAST(0 AS TINYINT)
THEN "t0"."a" * CAST(2 AS TINYINT)
ELSE CAST(NULL AS BIGINT)
ELSE NULL
END AS "tmp"
FROM "alltypes" AS "t0"
8 changes: 6 additions & 2 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def else_(self, result_expr: Any) -> Self:

def end(self) -> ir.Value | Deferred:
"""Finish the `CASE` expression."""
if not self.cases:
raise ValueError("At least one case must be specified")
return _finish_searched_case(self.cases, self.results, self.default)


Expand Down Expand Up @@ -98,8 +100,8 @@ def when(self, case_expr: Any, result_expr: Any) -> Self:

if not rlz.comparable(self.base, case_expr.op()):
raise TypeError(
f"Base expression {rlz._arg_type_error_format(self.base)} and "
f"case {rlz._arg_type_error_format(case_expr)} are not comparable"
f"Base expression {rlz.arg_type_error_format(self.base)} and "
f"case {rlz.arg_type_error_format(case_expr.op())} are not comparable"
)
return self.copy(
cases=self.cases + (case_expr,), results=self.results + (result_expr,)
Expand All @@ -118,6 +120,8 @@ def else_(self, result_expr: Any) -> Self:

def end(self) -> ir.Value:
"""Finish the `CASE` expression."""
if not self.cases:
raise ValueError("At least one case must be specified")
if (default := self.default) is None:
default = ibis.null().cast(rlz.highest_precedence_dtype(self.results))
return ops.SimpleCase(
Expand Down
13 changes: 7 additions & 6 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,15 @@ class SimpleCase(Value):
results: VarTuple[Value]
default: Value

shape = rlz.shape_like("base")

def __init__(self, cases, results, **kwargs):
assert len(cases) == len(results)
super().__init__(cases=cases, results=results, **kwargs)

@attribute
def shape(self):
exprs = [self.base, *self.cases, *self.results, self.default]
return rlz.highest_precedence_shape(exprs)

@attribute
def dtype(self):
values = [*self.results, self.default]
Expand All @@ -315,14 +318,12 @@ class SearchedCase(Value):

def __init__(self, cases, results, default):
assert len(cases) == len(results)
if default.dtype.is_null():
default = Cast(default, rlz.highest_precedence_dtype(results))
super().__init__(cases=cases, results=results, default=default)

@attribute
def shape(self):
# TODO(kszucs): can be removed after making Sequence iterable
return rlz.highest_precedence_shape(self.cases)
exprs = [*self.cases, *self.results, self.default]
return rlz.highest_precedence_shape(exprs)

@attribute
def dtype(self):
Expand Down
12 changes: 6 additions & 6 deletions ibis/expr/operations/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(self, left, right):
"""
if not rlz.comparable(left, right):
raise IbisTypeError(
f"Arguments {rlz._arg_type_error_format(left)} and "
f"{rlz._arg_type_error_format(right)} are not comparable"
f"Arguments {rlz.arg_type_error_format(left)} and "
f"{rlz.arg_type_error_format(right)} are not comparable"
)
super().__init__(left=left, right=right)

Expand Down Expand Up @@ -121,13 +121,13 @@ class Between(Value):
def __init__(self, arg, lower_bound, upper_bound):
if not rlz.comparable(arg, lower_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(lower_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(lower_bound)} are not comparable"
)
if not rlz.comparable(arg, upper_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(upper_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(upper_bound)} are not comparable"
)
super().__init__(arg=arg, lower_bound=lower_bound, upper_bound=upper_bound)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _promote_interval_resolution(units: list[IntervalUnit]) -> IntervalUnit:
raise AssertionError("unreachable")


def _arg_type_error_format(op):
def arg_type_error_format(op: ops.Value) -> str:
if isinstance(op, ops.Literal):
return f"Literal({op.value}):{op.dtype}"
else:
Expand Down
68 changes: 68 additions & 0 deletions ibis/tests/expr/test_case.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import _
from ibis.common.annotations import SignatureValidationError
from ibis.tests.util import assert_equal, assert_pickle_roundtrip


Expand Down Expand Up @@ -44,6 +47,41 @@ def test_ifelse_function_deferred(table):
assert res.equals(sol)


def test_case_dshape(table):
assert isinstance(ibis.case().when(True, "bar").when(False, "bar").end(), ir.Scalar)
assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.Scalar)
assert isinstance(
ibis.case().when(table.b == 9, None).else_("bar").end(), ir.Column
)
assert isinstance(ibis.case().when(True, table.a).else_(42).end(), ir.Column)
assert isinstance(ibis.case().when(True, 42).else_(table.a).end(), ir.Column)
assert isinstance(ibis.case().when(True, table.a).else_(table.b).end(), ir.Column)

assert isinstance(ibis.literal(5).case().when(9, 42).end(), ir.Scalar)
assert isinstance(ibis.literal(5).case().when(9, 42).else_(43).end(), ir.Scalar)
assert isinstance(ibis.literal(5).case().when(table.a, 42).end(), ir.Column)
assert isinstance(ibis.literal(5).case().when(9, table.a).end(), ir.Column)
assert isinstance(ibis.literal(5).case().when(table.a, table.b).end(), ir.Column)
assert isinstance(
ibis.literal(5).case().when(9, 42).else_(table.a).end(), ir.Column
)
assert isinstance(table.a.case().when(9, 42).end(), ir.Column)
assert isinstance(table.a.case().when(table.b, 42).end(), ir.Column)
assert isinstance(table.a.case().when(9, table.b).end(), ir.Column)
assert isinstance(table.a.case().when(table.a, table.b).end(), ir.Column)


def test_case_dtype():
assert isinstance(
ibis.case().when(True, "bar").when(False, "bar").end(), ir.StringValue
)
assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.StringValue)
with pytest.raises(TypeError):
assert ibis.case().when(True, 5).when(False, "bar").end()
with pytest.raises(TypeError):
assert ibis.case().when(True, 5).else_("bar").end()


def test_simple_case_expr(table):
case1, result1 = "foo", table.a
case2, result2 = "bar", table.c
Expand Down Expand Up @@ -177,3 +215,33 @@ def test_case_mixed_type():
)
result = t0[expr]
assert result["label"].type().equals(dt.string)


def test_err_on_nonbool(table):
with pytest.raises(SignatureValidationError):
ibis.case().when(table.a, "bar").else_("baz").end()


@pytest.mark.xfail(reason="Literal('foo', type=bool), should error, but doesn't")
def test_err_on_nonbool2():
with pytest.raises(SignatureValidationError):
ibis.case().when("foo", "bar").else_("baz").end()


def test_err_on_noncomparable(table):
table.a.case().when(8, "bar").end()
table.a.case().when(-8, "bar").end()
# Can't compare an int to a string
with pytest.raises(TypeError):
table.a.case().when("foo", "bar").end()


def test_err_on_empty_cases(table):
with pytest.raises(ValueError):
ibis.case().end()
with pytest.raises(ValueError):
ibis.case(else_=42).end()
with pytest.raises(ValueError):
table.a.case().end()
with pytest.raises(ValueError):
table.a.case(else_=42).end()

0 comments on commit 133d4cb

Please sign in to comment.