Skip to content

Commit

Permalink
test #2637 too
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 13, 2024
1 parent eef8ad6 commit 16f947c
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_2637_jax_tracer_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations

import pytest

import awkward as ak

jax = pytest.importorskip("jax")


def test():
ak.jax.register_and_check()

jets = ak.Array(
[
[
{"pt": 1.0, "eta": 1.1, "phi": 0.1, "mass": 0.01},
{"pt": 2, "eta": 2.2, "phi": 0.2, "mass": 0.02},
],
[
{"pt": 4.0, "eta": 4.4, "phi": 0.4, "mass": 0.04},
{"pt": 5.0, "eta": 5.5, "phi": 0.5, "mass": 0.05},
{"pt": 6.0, "eta": 6.6, "phi": 0.6, "mass": 0.06},
],
],
backend="jax",
)

def correct_jets(jets, alpha):
new_pt = jets["pt"] + 25.0 * alpha
jets["pt"] = new_pt
return ak.sum(jets["pt"])

val, grad = jax.value_and_grad(correct_jets, argnums=1)(jets, 0.1)

assert val == 30.5
assert grad == 125.0

0 comments on commit 16f947c

Please sign in to comment.