Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove redundant(?) Jax.is_tracer_type check in _layout #3013

Merged
merged 3 commits into from
Feb 14, 2024

Conversation

Saransh-cpp
Copy link
Member

Fixes #2556 #2637

I am still not completely sure if removing the check is safe, but everything works for me locally, including the non-jax part of awkward.

#2637 now exits with a new error which is on the user side and not on the awkward side -

In [1]: import awkward as ak
   ...: import jax
   ...: import uproot
   ...: 
   ...: ak.jax.register_and_check()
   ...: 
   ...: ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
   ...:     "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
   ...: 
   ...: def correct_jets(jets, alpha):
   ...:     jets = ak.Array(jets)
   ...:     new_pt = jets["pt"] + 25*alpha
   ...:     jets["pt"] = new_pt
   ...:     return jets
   ...: 
   ...: with uproot.open(ttbar_file) as f:
   ...:     arr = f["Events"].arrays(["Jet_pt","Jet_eta", "Jet_phi", "Jet_mass"])
   ...:     evtfilter = ak.num(arr["Jet_pt"]) >= 2
   ...:     jets = ak.zip(dict(zip(["pt","eta", "phi", "mass"], ak.unzip(arr))), with_name="Momentum4D")[evtfilter]
   ...:     jets = ak.to_backend(jets, "jax")
   ...: 
   ...: 
   ...: jax.value_and_grad(correct_jets, argnums=1)(jets, 0.1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
    [... skipping hidden 1 frame]

/opt/homebrew/lib/python3.11/site-packages/jax/_src/core.py in get_aval(x)
   1497   else:
-> 1498     return concrete_aval(x)
   1499 

/opt/homebrew/lib/python3.11/site-packages/jax/_src/core.py in concrete_aval(x)
   1489     return concrete_aval(x.__jax_array__())
-> 1490   raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1491                    "type")

TypeError: Value <Array [[{eta: -3.1967773, ...}, ...], ...] type='140 * var * Momentum4D[et...'> with type <class 'awkward.highlevel.Array'> is not a valid JAX type

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-1-d13645fe92dc> in <cell line: 0>()
     21 
     22 
---> 23 jax.value_and_grad(correct_jets, argnums=1)(jets, 0.1)

    [... skipping hidden 2 frame]

/opt/homebrew/lib/python3.11/site-packages/jax/_src/api.py in _check_scalar(x)
    753     aval = core.get_aval(x)
    754   except TypeError as e:
--> 755     raise TypeError(msg(f"was {x}")) from e
    756   else:
    757     if isinstance(aval, ShapedArray):

TypeError: Gradient only defined for scalar-output functions. Output was [[{eta: -3.1967773, phi: 2.9589844, mass: 3.3886719, pt: ..., ...}, ...], ...].

cc: @alexander-held

@codecov-commenter
Copy link

codecov-commenter commented Feb 8, 2024

Codecov Report

Attention: 1 lines in your changes are missing coverage. Please review.

Comparison is base (b749e49) 81.90% compared to head (16f947c) 81.91%.
Report is 15 commits behind head on main.

Additional details and impacted files
Files Coverage Δ
src/awkward/_connect/cuda/__init__.py 0.00% <ø> (ø)
src/awkward/_layout.py 86.48% <ø> (+0.24%) ⬆️
src/awkward/operations/ak_enforce_type.py 81.97% <ø> (ø)
src/awkward/operations/ak_to_dataframe.py 90.76% <100.00%> (+0.14%) ⬆️
src/awkward/_nplikes/typetracer.py 74.85% <0.00%> (ø)

@Saransh-cpp
Copy link
Member Author

Saransh-cpp commented Feb 8, 2024

Should I install jax in the workflows or are we not supposed to use non-dependency libraries in the tests? (or maybe pytest.importorskip?)

@alexander-held
Copy link
Member

#2637 now exits with a new error which is on the user side and not on the awkward side -

That looks fine, the reproducer should have returned a scalar from the function being differentiated. When doing that by using e.g. return ak.sum(jets.pt) in correct_jets, this code runs using the branch here. I think #2637 could then also be closed by this PR.

@agoose77
Copy link
Collaborator

agoose77 commented Feb 8, 2024

@Saransh-cpp importorskip should be used for any non-core dependency in our tests. Most jobs test jax, but some limited "minimal" runs do not, and will error if it's required.

@Saransh-cpp Saransh-cpp linked an issue Feb 8, 2024 that may be closed by this pull request
@agoose77
Copy link
Collaborator

@Saransh-cpp this has my approval if you want to "see what happens". My feeling is that our JAX user base is small enough that we can move fast and break things (I'm sure Jim agrees).

Copy link
Member

@jpivarski jpivarski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know of anything wrong with this. We can just remove it and "see what happens" because it only affects users of the JAX backend. In general, we can't move fast and break things with this library, but only for a limited context like this backend.

This guard was added by @agoose77 in #2389, and maybe he remembers why it was added.

I approve the PR to be merged. If, after merging, we find out that it was not a good idea, you know how to reinstate it.

@Saransh-cpp
Copy link
Member Author

Noted, thanks!

@Saransh-cpp Saransh-cpp merged commit dc170bb into scikit-hep:main Feb 14, 2024
39 checks passed
@Saransh-cpp Saransh-cpp deleted the rm-jax-tracer-error branch February 14, 2024 09:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Jax tracers and adding scalars to arrays Jax tracer TypeError in multiplication
5 participants