Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Return-types #11] Jax interface support shot vectors #3234

Merged
merged 106 commits into from
Nov 5, 2022

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Oct 27, 2022

Description of the Change:

This PR integrates shot vectors for QNode with JAX and derivates jacobian, hessian.

Benefits:

Users can use shot vectors everywhere with Jax QNode.

Example:

        import jax
        from jax.config import config

        config.update("jax_enable_x64", True)

        if diff_method == "adjoint":
            pytest.skip("Test does not supports adjoint because second order diff.")

        dev = qml.device(dev_name, wires=2, shots=(1, 10, 100))

        params = jax.numpy.array([0.1, 0.2])

        @qnode(dev, interface="jax", diff_method="parameter-shift", max_diff=2)
        def circuit(x):
            qml.RX(x[0], wires=[0])
            qml.RY(x[1], wires=[1])
            qml.CNOT(wires=[0, 1])
            return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.probs(wires=[0, 1])

        hess = jax.hessian(circuit)(params)
        print(hess)
((DeviceArray([[ 0. , -0.5],
             [-0.5,  0. ]], dtype=float64), DeviceArray([[[-5.00000000e-01,  0.00000000e+00],
              [ 0.00000000e+00,  0.00000000e+00]],

             [[ 0.00000000e+00,  2.50000000e-01],
              [ 2.50000000e-01,  0.00000000e+00]],

             [[ 0.00000000e+00, -1.23259516e-32],
              [ 0.00000000e+00,  0.00000000e+00]],

             [[ 5.00000000e-01, -2.50000000e-01],
              [-2.50000000e-01,  0.00000000e+00]]], dtype=float64)), (DeviceArray([[-0.15 , -0.075],
             [-0.075,  0.   ]], dtype=float64), DeviceArray([[[-2.00000000e-01, -3.75000000e-02],
              [-3.75000000e-02,  0.00000000e+00]],

             [[-3.00000000e-01, -1.25000000e-02],
              [-1.25000000e-02,  0.00000000e+00]],

             [[ 3.75000000e-01,  5.00000000e-02],
              [ 5.00000000e-02,  0.00000000e+00]],

             [[ 1.25000000e-01, -5.55111512e-18],
              [ 0.00000000e+00,  0.00000000e+00]]], dtype=float64)), (DeviceArray([[-0.27 , -0.135],
             [-0.135, -0.26 ]], dtype=float64), DeviceArray([[[-0.325 , -0.0375],
              [-0.0375, -0.13  ]],

             [[-0.175 ,  0.035 ],
              [ 0.035 ,  0.13  ]],

             [[ 0.31  ,  0.0325],
              [ 0.0325,  0.    ]],

             [[ 0.19  , -0.03  ],
              [-0.03  ,  0.    ]]], dtype=float64)))

@rmoyard rmoyard added the review-ready 👌 PRs which are ready for review by someone from the core team. label Nov 2, 2022
Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rmoyard this is amazing work! 🔥 Requesting changes because I think _shots_copies needs some improvements as commented, but the changes shouldn't be major. The PR otherwise is solid.

In addition, ideally, I think it would be best to separate the fix/improvement for compatibility with shot vector defs that contain tuples because it's more general than the JAX interface. Having said that, I see that making it stand-alone would require further testing.

I would be happy to help with this separation by adding unit tests, let me know. I think it's important that we separate such changes because otherwise they will be lost in some bigger scope PRs and it will be challenging to track them down.

Thanks for making the shot tuple fix btw!

pennylane/gradients/jvp.py Outdated Show resolved Hide resolved
pennylane/gradients/jvp.py Outdated Show resolved Hide resolved
pennylane/_qubit_device.py Show resolved Hide resolved
pennylane/gradients/finite_difference.py Outdated Show resolved Hide resolved
pennylane/interfaces/jax.py Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
@rmoyard rmoyard requested a review from antalszava November 4, 2022 13:29
pennylane/interfaces/jax.py Outdated Show resolved Hide resolved
pennylane/interfaces/jax.py Show resolved Hide resolved
tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
pennylane/interfaces/jax.py Outdated Show resolved Hide resolved
Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rmoyard, looks good! 🎉 Amazing work. 💯

Approving on the condition of:

  • Shot vector tests being separated into a file;
  • Mixing up the shots def to have the tuple be at the 1. pos (see suggestion).

The rest of the comments are more minor.

tests/returntypes/test_jax_qnode_new.py Outdated Show resolved Hide resolved
rmoyard and others added 5 commits November 4, 2022 15:32
Co-authored-by: antalszava <antalszava@gmail.com>
Co-authored-by: antalszava <antalszava@gmail.com>
Copy link
Contributor

@eddddddy eddddddy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rmoyard, this is looking solid 💯

Just have a few minor questions, but otherwise this looks good on my end.

pennylane/_qubit_device.py Show resolved Hide resolved
pennylane/_device.py Outdated Show resolved Hide resolved
pennylane/gradients/jvp.py Show resolved Hide resolved
pennylane/gradients/finite_difference.py Outdated Show resolved Hide resolved
pennylane/gradients/parameter_shift.py Outdated Show resolved Hide resolved
pennylane/gradients/jvp.py Show resolved Hide resolved
@antalszava antalszava merged commit 21b35b6 into master Nov 5, 2022
@antalszava antalszava deleted the return_jax_shots_vector branch November 5, 2022 00:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants