Skip to content

Commit

Permalink
fix: extend TypeTracerArray with __eq__, __ne__, and __array_ufunc__. (
Browse files Browse the repository at this point in the history
…#2021)

* fix: extend TypeTracerArray with __eq__, __ne__, and __array_ufunc__.

* Added a test to ensure that this fix survives a refactoring.

* Fixes the 'invalid value encountered in cast' error.

* This should fix the precision error.
  • Loading branch information
jpivarski authored Dec 19, 2022
1 parent 9ab248c commit 2b85171
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 32 deletions.
19 changes: 19 additions & 0 deletions src/awkward/_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,18 @@ def __getitem__(self, where):
else:
raise ak._errors.wrap_error(NotImplementedError(repr(where)))

def __eq__(self, other):
if isinstance(other, numbers.Real):
return TypeTracerArray(np.bool_, self._shape)
else:
return NotImplemented

def __ne__(self, other):
if isinstance(other, numbers.Real):
return TypeTracerArray(np.bool_, self._shape)
else:
return NotImplemented

def __lt__(self, other):
if isinstance(other, numbers.Real):
return TypeTracerArray(np.bool_, self._shape)
Expand Down Expand Up @@ -485,6 +497,13 @@ def reshape(self, *args):
def copy(self):
return self

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
replacements = [
numpy.empty(0, x.dtype) if hasattr(x, "dtype") else x for x in inputs
]
result = getattr(ufunc, method)(*replacements, **kwargs)
return TypeTracerArray(result.dtype, shape=self._shape)


class TypeTracer(ak._nplikes.NumpyLike):
known_data = False
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/operations/ak_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ def to_numpy(array, *, allow_missing=True):
"ak.to_numpy",
dict(array=array, allow_missing=allow_missing),
):
return ak._util.to_arraylib(numpy, array, allow_missing)
with numpy.errstate(invalid="ignore"):
return ak._util.to_arraylib(numpy, array, allow_missing)
99 changes: 68 additions & 31 deletions tests/test_0355-mixins.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numbers

import numpy as np
import pytest # noqa: F401
import pytest

import awkward as ak

to_list = ak.operations.to_list


def _assert_equal_enough(obtained, expected):
if isinstance(obtained, dict):
assert isinstance(expected, dict)
assert set(obtained.keys()) == set(expected.keys())
for key in obtained.keys():
_assert_equal_enough(obtained[key], expected[key])
elif isinstance(obtained, list):
assert isinstance(expected, list)
assert len(obtained) == len(expected)
for x, y in zip(obtained, expected):
_assert_equal_enough(x, y)
elif isinstance(obtained, tuple):
assert isinstance(expected, tuple)
assert len(obtained) == len(expected)
for x, y in zip(obtained, expected):
_assert_equal_enough(x, y)
elif isinstance(obtained, numbers.Real) and isinstance(expected, numbers.Real):
assert pytest.approx(obtained) == expected
else:
assert obtained == expected


def assert_equal_enough(obtained, expected):
_assert_equal_enough(obtained.tolist(), expected)


def test_make_mixins():
behavior = {}

Expand Down Expand Up @@ -82,37 +110,46 @@ def weighted_add(self, other):
[],
[{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}],
]
assert to_list(wone + wtwo) == [
assert_equal_enough(
wone + wtwo,
[
{
"x": 0.9524937500390619,
"y": 1.052493750039062,
"weight": 2.831969279439222,
},
{"x": 2.0, "y": 2.2, "weight": 5.946427498927402},
{
"x": 2.9516640394605282,
"y": 3.1549921183815837,
"weight": 8.632349833200564,
},
[
{
"x": 0.9524937500390619,
"y": 1.052493750039062,
"weight": 2.831969279439222,
},
{"x": 2.0, "y": 2.2, "weight": 5.946427498927402},
{
"x": 2.9516640394605282,
"y": 3.1549921183815837,
"weight": 8.632349833200564,
},
],
[],
[
{
"x": 3.9515600270076154,
"y": 4.206240108030463,
"weight": 11.533018588312771,
},
{"x": 5.0, "y": 5.5, "weight": 14.866068747318506},
],
],
[],
)
assert_equal_enough(
abs(one),
[
{
"x": 3.9515600270076154,
"y": 4.206240108030463,
"weight": 11.533018588312771,
},
{"x": 5.0, "y": 5.5, "weight": 14.866068747318506},
[1.4866068747318506, 2.973213749463701, 4.459820624195552],
[],
[5.946427498927402, 7.433034373659253],
],
]
assert to_list(abs(one)) == [
[1.4866068747318506, 2.973213749463701, 4.459820624195552],
[],
[5.946427498927402, 7.433034373659253],
]
assert to_list(one.distance(wtwo)) == [
[0.14142135623730953, 0.0, 0.31622776601683783],
[],
[0.4123105625617664, 0.0],
]
)
assert_equal_enough(
one.distance(wtwo),
[
[0.14142135623730953, 0.0, 0.31622776601683783],
[],
[0.4123105625617664, 0.0],
],
)
18 changes: 18 additions & 0 deletions tests/test_2021-check-TypeTracerArray-in-ak-where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import awkward as ak


def test():
conditionals = ak.Array([True, True, True, False, False, False])
unionarray = ak.Array([1, 2, 3, [4, 5], [], [6]])
otherarray = ak.Array(range(100, 106))
result = ak.where(conditionals, unionarray, otherarray)
assert result.tolist() == [1, 2, 3, 103, 104, 105]
assert str(result.type) == "6 * union[int64, var * int64]"

conditionals_tt = ak.Array(conditionals.layout.to_typetracer())
unionarray_tt = ak.Array(unionarray.layout.to_typetracer())
otherarray_tt = ak.Array(otherarray.layout.to_typetracer())
result_tt = ak.where(conditionals_tt, unionarray_tt, otherarray_tt)
assert str(result_tt.type) == "6 * union[int64, var * int64]"

0 comments on commit 2b85171

Please sign in to comment.