Jax tracers and adding scalars to arrays #2637

Closed
Tracked by #3
alexander-held opened this issue Aug 10, 2023 · 3 comments · Fixed by #3013
Labels: autodiff (Issue related to auto-differentiation), bug (The problem described is something that must be fixed)

Comments

alexander-held (Member) commented Aug 10, 2023

Version of Awkward Array

ce63bf2 (main from 2 weeks ago)

Description and code to reproduce

The following currently does not work. I am trying to differentiate through an operation that adds a scalar to an array:

import awkward as ak
import jax
import uproot

ak.jax.register_and_check()

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

def correct_jets(jets, alpha):
    jets = ak.Array(jets)
    new_pt = jets["pt"] + 25*alpha
    jets["pt"] = new_pt
    return jets

with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Jet_pt","Jet_eta", "Jet_phi", "Jet_mass"])
    evtfilter = ak.num(arr["Jet_pt"]) >= 2
    jets = ak.zip(dict(zip(["pt","eta", "phi", "mass"], ak.unzip(arr))), with_name="Momentum4D")[evtfilter]
    jets = ak.to_backend(jets, "jax")


jax.value_and_grad(correct_jets, argnums=1)(jets, 0.1)

Result:

TypeError: Jax tracers cannot be used with `ak.from_arraylib`

This error occurred while calling

    numpy.add.__call__(
        <Array [[17.921875, 15.734375], ..., [...]] type='140 * var * float32'>
        JVPTracer-instance
    )

Am I setting this up wrong, or is this a bug?
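For comparison, here is a minimal plain-JAX sketch of the same scalar-plus-array step. It uses a flat `jax.numpy` array instead of an Awkward Array, and the array holds the two pT values visible in the error message above, purely for illustration. Under `jax.value_and_grad`, `alpha` is replaced by a tracer (the `JVPTracer-instance` in the error message). Because `value_and_grad` expects a scalar output, the sketch sums the shifted values.

```python
import jax
import jax.numpy as jnp

# Illustrative values: the first event's two jet pT values from the error message.
pt = jnp.array([17.921875, 15.734375], dtype=jnp.float32)

def shifted_pt_sum(alpha):
    # Inside value_and_grad, alpha is a JAX tracer rather than a Python float.
    new_pt = pt + 25 * alpha
    # value_and_grad needs a scalar output, so reduce the shifted pT values.
    return jnp.sum(new_pt)

value, grad = jax.value_and_grad(shifted_pt_sum)(0.1)
print(float(grad))  # 50.0: each of the two entries contributes 25
```

Plain JAX handles the tracer in `pt + 25 * alpha` without complaint, which is the behaviour one would expect from the Awkward version as well.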

alexander-held added the bug (unverified) label ("The problem described would be a bug, but needs to be triaged") on Aug 10, 2023
alexander-held (Member, Author) commented:

Another thing I've been wondering about in this context: the trace points to an API, ak.from_arraylib, that I cannot find in the documentation. It lives in the internal module awkward._layout as ak._layout.from_arraylib, which explains why it is not documented. Is there a way to make the error show its correct full path? I tried to understand how it gets imported, and my guess is that ak.from_jax (which comes in via ak.operations) gives it this name via

from awkward._layout import from_arraylib, wrap_layout
but then it is unclear to me why I can't call ak.from_arraylib myself.
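A general-Python sketch of the import behaviour this question touches on. The module names `pkg._internal`, `pkg.ops` and `pkg` are hypothetical stand-ins, and the sketch assumes (not confirmed here) that the top-level namespace only re-exports the public names listed in each operation module's `__all__`. A `from ... import` inside a submodule binds the name in that submodule only, and the function's `__module__` still points at the module that defined it.

```python
import types

# Hypothetical internal module (stand-in for something like awkward._layout).
internal = types.ModuleType("pkg._internal")
exec(
    "def from_arraylib(array):\n"
    "    return list(array)\n",
    internal.__dict__,
)

# Hypothetical operation module. Binding the name here mirrors
# `from pkg._internal import from_arraylib` at the top of pkg/ops.py.
ops = types.ModuleType("pkg.ops")
ops.from_arraylib = internal.from_arraylib
exec(
    "__all__ = ['from_jax']\n"
    "def from_jax(array):\n"
    "    return from_arraylib(array)\n",
    ops.__dict__,
)

# Hypothetical top-level package. `from pkg.ops import *` copies only the
# names listed in ops.__all__, so the internal helper is not re-exported.
pkg = types.ModuleType("pkg")
for name in ops.__all__:
    setattr(pkg, name, getattr(ops, name))

print(hasattr(pkg, "from_jax"))       # True
print(hasattr(pkg, "from_arraylib"))  # False
print(hasattr(ops, "from_arraylib"))  # True
print(ops.from_arraylib.__module__)   # pkg._internal
```

Under these assumptions the helper is reachable through the operation module and its defining module, but not through the package's top level.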

jpivarski added the autodiff label ("Issue related to auto-differentiation") on Oct 2, 2023
jpivarski (Member) commented:

Another autodiff issue to self-assign, @Saransh-cpp. Thanks!

Saransh-cpp self-assigned this on Jan 20, 2024
Saransh-cpp (Member) commented:

Linking #2556 (comment) because the error originates from exactly the same place.

Saransh-cpp added a commit to Saransh-cpp/awkward that referenced this issue Feb 8, 2024
Saransh-cpp added the bug label ("The problem described is something that must be fixed") and removed the bug (unverified) label ("The problem described would be a bug, but needs to be triaged") on Feb 12, 2024
Saransh-cpp added a commit to Saransh-cpp/awkward that referenced this issue Feb 13, 2024
Saransh-cpp added a commit that referenced this issue Feb 14, 2024
* fix: remove redundant(?) Jax.is_tracer_type check in _layout

* use importorskip

* test #2637 too