Skip to content

Commit

Permalink
fix: use to_backend_array() instead of asarray (#2592)
Browse files Browse the repository at this point in the history
* fix: use to_backend_array() instead of `asarray`

* test: ensure nan_to_num doesn't break

* refactor: remove `asarray`
  • Loading branch information
agoose77 authored Jul 28, 2023
1 parent d3e494a commit ce63bf2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/awkward/operations/ak_nan_to_num.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,17 @@ def action(layout, backend, **kwargs):

def action(inputs, backend, **kwargs):
if all(isinstance(x, ak.contents.NumpyArray) for x in inputs):
tmp_layout = backend.nplike.asarray(inputs[0])
tmp_layout = inputs[0].data
if id(nan) in broadcasting_ids:
tmp_nan = backend.nplike.asarray(inputs[broadcasting_ids[id(nan)]])
tmp_nan = inputs[broadcasting_ids[id(nan)]].to_backend_array()
else:
tmp_nan = nan
if id(posinf) in broadcasting_ids:
tmp_posinf = backend.nplike.asarray(
inputs[broadcasting_ids[id(posinf)]]
)
tmp_posinf = inputs[broadcasting_ids[id(posinf)]].to_backend_array()
else:
tmp_posinf = posinf
if id(neginf) in broadcasting_ids:
tmp_neginf = backend.nplike.asarray(
inputs[broadcasting_ids[id(neginf)]]
)
tmp_neginf = inputs[broadcasting_ids[id(neginf)]].to_backend_array()
else:
tmp_neginf = neginf
return (
Expand Down
15 changes: 15 additions & 0 deletions tests/test_2591_nan_to_num.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import pytest

import awkward as ak

pytest.importorskip("jax")

ak.jax.register_and_check()


def test():
ak.nan_to_num(
ak.Array([1, 2, 3], backend="jax"), nan=ak.Array([1, 2, 3], backend="jax")
)

0 comments on commit ce63bf2

Please sign in to comment.