Skip to content

Commit

Permalink
fix(api): improve error message raised on improper calls to array `ma…
Browse files Browse the repository at this point in the history
…p` or `filter` (#8602)
  • Loading branch information
jcrist authored Mar 9, 2024
1 parent 71089fe commit 0236370
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 44 deletions.
27 changes: 17 additions & 10 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,19 @@ def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
"""
if isinstance(func, Deferred):
name = "_"
else:
resolve = func.resolve
elif callable(func):
name = next(iter(inspect.signature(func).parameters.keys()))
resolve = func
else:
raise TypeError(
f"`func` must be a Deferred or Callable, got `{type(func).__name__}`"
)

parameter = ops.Argument(
name=name, shape=self.op().shape, dtype=self.type().value_type
)
if isinstance(func, Deferred):
body = func.resolve(parameter.to_expr())
else:
body = func(parameter.to_expr())
body = resolve(parameter.to_expr())
return ops.ArrayMap(self, param=parameter.param, body=body).to_expr()

def filter(
Expand Down Expand Up @@ -545,17 +549,20 @@ def filter(
"""
if isinstance(predicate, Deferred):
name = "_"
else:
resolve = predicate.resolve
elif callable(predicate):
name = next(iter(inspect.signature(predicate).parameters.keys()))
resolve = predicate
else:
raise TypeError(
f"`predicate` must be a Deferred or Callable, got `{type(predicate).__name__}`"
)
parameter = ops.Argument(
name=name,
shape=self.op().shape,
dtype=self.type().value_type,
)
if isinstance(predicate, Deferred):
body = predicate.resolve(parameter.to_expr())
else:
body = predicate(parameter.to_expr())
body = resolve(parameter.to_expr())
return ops.ArrayFilter(self, param=parameter.param, body=body).to_expr()

def contains(self, other: ir.Value) -> ir.BooleanValue:
Expand Down
52 changes: 18 additions & 34 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,52 +1533,36 @@ def test_array_length_scalar():
assert isinstance(expr.op(), ops.ArrayLength)


def double_int(x):
return x * 2


def double_float(x):
return x * 2.0


def is_negative(x):
return x < 0


def test_array_map():
arr = ibis.array([1, 2, 3])

result_int = arr.map(double_int)
result_float = arr.map(double_float)
r1 = arr.map(_ * 2)
r2 = arr.map(lambda x: x * 2.0)
r3 = arr.map(functools.partial(lambda a, b: a + b, b=2))

assert result_int.type() == dt.Array(dt.int16)
assert result_float.type() == dt.Array(dt.float64)
assert r1.type() == dt.Array(dt.int16)
assert r2.type() == dt.Array(dt.float64)
assert r3.type() == dt.Array(dt.int16)


def test_array_map_partial():
arr = ibis.array([1, 2, 3])

def add(x, y):
return x + y

result = arr.map(functools.partial(add, y=2))
assert result.type() == dt.Array(dt.int16)
with pytest.raises(TypeError, match="must be a Deferred or Callable"):
# Non-deferred expressions aren't allowed
arr.map(arr[0])


def test_array_filter():
arr = ibis.array([1, 2, 3])
result = arr.filter(is_negative)
assert result.type() == arr.type()


def test_array_filter_partial():
arr = ibis.array([1, 2, 3])
r1 = arr.filter(lambda x: x < 0)
r2 = arr.filter(_ < 0)
r3 = arr.filter(functools.partial(lambda a, b: a == b, b=2))

def equal(x, y):
return x == y
assert r1.type() == arr.type()
assert r2.type() == arr.type()
assert r3.type() == arr.type()

result = arr.filter(functools.partial(equal, y=2))
assert result.type() == arr.type()
with pytest.raises(TypeError, match="must be a Deferred or Callable"):
# Non-deferred expressions aren't allowed
arr.filter(arr[0])


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0236370

Please sign in to comment.