Skip to content

Commit

Permalink
Allow vector-valued QNodes with JAX using host_callback.call (#2034)
Browse files Browse the repository at this point in the history
* getting shapes changed

* Adjust tests

* format

* vector valued test case

* copy the output_dim too

* qml.density_matrix

* var

* sampling dim

* squeeze

* prep for shape addition

* logic and tests first push

* first attempt for shape

* more single measure tests

* TODOs

* tests

* format

* polish comments

* correct call

* correct call

* use kwargs in measure.py

* lint

* output domain

* sampling

* re-add obs as a potential positional argument

* adjust the default; add more tests

* docstrings; qml.sample() out domain case

* qml.probs example with dummy device defining a cutoff

* error handling with tests

* format

* checking shape outputs

* coverage

* formatting

* changelog

* qml.density_matrix case

* comments; polishing

* format

* more test; order

* format

* Update pennylane/measure.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* format

* docstring

* change path

* shape intro order change

* correct previous resolution

* resolve

* docstring

* docstring

* QuantumTape

* no print

* merge master

* Update pennylane/tape/tape.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* linting

* filter warning ragged nested sequences

* Update pennylane/tape/tape.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* ignore different arguments linting error

* no test raises multi-expval

* more notes

* rename num_systems

* updates

* rename to shape and result_type

* numeric_type

* docstring

* linting issue fix

* test updates

* refactor

* docstring

* more tests

* changelog

* changelog

* comment

* test

* update docstring

* tests

* swap qml.eigvals use to checking the type of observables

* Update pennylane/measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* example

* examples in the docstring

* Update pennylane/measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Update pennylane/measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Update pennylane/measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* lru cache

* docstring

* Update pennylane/tape/tape.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* smaller shape doc for tape

* apply suggestions for the docstrings; have tape numeric_type as property

* docstring

* fix state and density_matrix for state vector

* state tests

* changelog & test for uncovered DefaultQubitTF line

* renames, new locations

* fwd shapes; docstring; example adjustments

* jax jit QNode test cases

* docstring

* lint test_wires.py

* no print; no error for now

* new shape and prints

* [skip ci]

* typo correct

* double prec in test file to have no warnings

* err and tests, [skip ci]

* Make numeric_type a property because we always want it to be defined and users/developers may forget its def when creating a MP

* remove print sttment

* skip jacobian calc for probs differentiation [skip ci]

* remove print statement and unused meas

* no need for xfail

* remove old scratch

* more

* more

* refactor shape logic

* Fix the include pattern used for running black

* Run black on pennylane

* Run black on tests

* changes

* Update error type; add tests for unrecognized return types

* [skip ci]

* multi-tape unit test

* gradient scalar cost func using vv QNode (jitting)

* gradient scalar cost func using vv QNode (jitting)

* unit test shapes

* changelog

* docs update, changelog

* docs

* Update doc/releases/changelog-dev.md

Co-authored-by: Ali Asadi <ali@xanadu.ai>

* Update tests/interfaces/test_jax.py

Co-authored-by: Ali Asadi <ali@xanadu.ai>

* type return

* note on shape

* refactor forward pass shape and dtype struct building

* correct changelog

* no int case anymore: would return a single element tuple; comment is not relevant, it's important measurements.py

Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: Ali Asadi <ali@xanadu.ai>
3 people authored Jun 10, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 1b788a6 commit 29557dd
Showing 12 changed files with 589 additions and 234 deletions.
15 changes: 9 additions & 6 deletions doc/introduction/interfaces/jax.rst
Original file line number Diff line number Diff line change
@@ -137,14 +137,17 @@ the ``@jax.jit`` decorator can be directly applied to the QNode.
the computation was just-in-time compiled. This is done by checking if any
of the input parameters were subject to a JAX transformation. If so, a
variant of the interface that supports the just-in-time compilation of
scalar-valued QNodes (i.e., those that have a single expectation value or
variance measurement) will be used. This is equivalent to passing
``interface='jax-jit'``.
QNodes will be used. This is equivalent to passing ``interface='jax-jit'``.

Computing the jacobian of vector-valued QNodes is not supported with the
JAX JIT interface. The output of vector-valued QNodes can, however, be used
in the definition of scalar-valued cost functions whose gradients can be
computed.

Specify ``interface='jax-python'`` to enforce support for computing the
forward and backward pass of vector-valued QNodes (e.g., QNodes with
probability, state or multiple expectation value measurements). This
option does not support just-in-time compilation.
backward pass of vector-valued QNodes (e.g., QNodes with probability, state
or multiple expectation value measurements). This option does not support
just-in-time compilation.


Randomness: Shots and Samples
53 changes: 50 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
@@ -4,6 +4,52 @@

<h3>New features since last release</h3>

* The JAX JIT interface now supports evaluating vector-valued QNodes
enabling new types of workflows to utilize the power of just-in-time
compilation for significant performance boosts.
[(#2034)](https://github.com/PennyLaneAI/pennylane/pull/2034)

Vector-valued QNodes include those with:
* `qml.probs`;
* `qml.state`;
* `qml.sample` or
* multiple `qml.expval` / `qml.var` measurements.

Consider a QNode that returns basis-state probabilities:
```python
dev = qml.device('default.qubit', wires=2)
x = jnp.array(0.543)
y = jnp.array(-0.654)

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift", interface="jax")
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.probs(wires=[1])
```
The QNode can now be evaluated:
```pycon
>>> circuit(x, y)
DeviceArray([0.8397495 , 0.16025047], dtype=float32)
```
Computing the jacobian of vector-valued QNodes is not supported with the JAX
JIT interface. The output of vector-valued QNodes can, however, be used in
the definition of scalar-valued cost functions whose gradients can be
computed.

For example, one can define a cost function that outputs the first element of
the probability vector:
```python
def cost(x, y):
return circuit(x, y)[0]
```
```pycon
>>> jax.grad(cost, argnums=[0])(x, y)
(DeviceArray(-0.2050439, dtype=float32),)
```

* A new quantum information module is added. It includes a function for computing the reduced density matrix functions
for state vectors and density matrices.

@@ -663,6 +709,7 @@

This release contains contributions from (in alphabetical order):

Amintor Dusko, Ankit Khandelwal, Avani Bhardwaj, Chae-Yeun Park, Christian Gogolin, Christina Lee, David Wierichs, Edward Jiang, Guillermo Alonso-Linaje,
Jay Soni, Juan Miguel Arrazola, Katharine Hyatt, Korbinian Kottmann, Maria Schuld, Mikhail Andrenkov, Romain Moyard,
Qi Hu, Samuel Banning, Soran Jahangiri, Utkarsh Azad, WingCode
Amintor Dusko, Ankit Khandelwal, Avani Bhardwaj, Chae-Yeun Park, Christian Gogolin, Christina Lee, David Wierichs,
Edward Jiang, Guillermo Alonso-Linaje, Jay Soni, Juan Miguel Arrazola, Katharine Hyatt, Korbinian Kottmann,
Maria Schuld, Mikhail Andrenkov, Romain Moyard, Qi Hu, Samuel Banning, Soran Jahangiri, Utkarsh Azad, Antal Száva,
WingCode
5 changes: 4 additions & 1 deletion pennylane/devices/tests/test_wires.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ def circuit():
# =====


# pylint: disable=too-few-public-methods
class TestWiresIntegration:
"""Test that the device integrates with PennyLane's wire management."""

@@ -54,7 +55,9 @@ class TestWiresIntegration:
],
)
@pytest.mark.parametrize("circuit_factory", [make_simple_circuit_expval])
def test_wires_expval(self, device, circuit_factory, wires1, wires2, tol):
def test_wires_expval(
self, device, circuit_factory, wires1, wires2, tol
): # pylint: disable=too-many-arguments
"""Test that the expectation of a circuit is independent from the wire labels used."""
dev1 = device(wires1)
dev2 = device(wires2)
2 changes: 1 addition & 1 deletion pennylane/interfaces/autograd.py
Original file line number Diff line number Diff line change
@@ -58,7 +58,7 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
tape.trainable_params = qml.math.get_trainable_indices(params)

# pylint misidentifies autograd.builtins as a dict
# pylint:disable=no-member
# pylint: disable=no-member
parameters = autograd.builtins.tuple(
[autograd.builtins.list(t.get_parameters()) for t in tapes]
)
53 changes: 31 additions & 22 deletions pennylane/interfaces/jax.py
Original file line number Diff line number Diff line change
@@ -260,6 +260,36 @@ def wrapped_exec_bwd(params, g):
return wrapped_exec(params)


def _raise_vector_valued_fwd(tapes):
"""Raises an error for vector-valued tapes in forward mode due to incorrect
results being produced.
There is an issue when jax.jacobian is being used, either due to issues
with tensor updating (TypeError: Updates tensor must be of rank 0; got 1)
or because jax.vmap introduces a redundant dimensionality in the result by
duplicating entries.
Example to the latter:
1. Output when using jax.jacobian:
DeviceArray([[-0.09983342, 0.01983384],\n
[-0.09983342, 0.01983384]], dtype=float64),
DeviceArray([[ 0. , -0.97517033],\n
[ 0. , -0.97517033]], dtype=float64)),
2. Expected output:
DeviceArray([[-0.09983342, 0.01983384],\n
[ 0. , -0.97517033]]
The output produced by this function matches 1.
"""
scalar_outputs = all(t.output_dim == 1 for t in tapes)
if not scalar_outputs:
raise InterfaceUnsupportedError(
"Computing the jacobian of vector-valued tapes is not supported currently in forward mode."
)


def _execute_with_fwd(
params,
tapes=None,
@@ -298,28 +328,7 @@ def wrapped_exec_bwd(params, g):
# Use the jacobian that was computed on the forward pass
jacs, params = params

# Note: there is an issue when jax.jacobian is being used, either due
# to issues with tensor updating (TypeError: Updates tensor must be of
# rank 0; got 1) or because jax.vmap introduces a redundant
# dimensionality in the result by duplicating entries
# Example to the latter:
#
# 1. Output when using jax.jacobian:
# DeviceArray([[-0.09983342, 0.01983384],\n
# [-0.09983342, 0.01983384]], dtype=float64),
# DeviceArray([[ 0. , -0.97517033],\n
# [ 0. , -0.97517033]], dtype=float64)),
#
# 2. Expected output:
# DeviceArray([[-0.09983342, 0.01983384],\n
# [ 0. , -0.97517033]]
#
# The output produced by this function matches 2.
scalar_outputs = all(t.output_dim == 1 for t in tapes)
if not scalar_outputs:
raise InterfaceUnsupportedError(
"Computing the jacobian of vector-valued tapes is not supported currently in forward mode."
)
_raise_vector_valued_fwd(tapes)

# Adjust the structure of how the jacobian is returned to match the
# non-forward mode cases
88 changes: 52 additions & 36 deletions pennylane/interfaces/jax_jit.py
Original file line number Diff line number Diff line change
@@ -23,9 +23,8 @@

import numpy as np
import pennylane as qml

from pennylane.measurements import Variance, Expectation, VnEntropy, MutualInfo
from pennylane.interfaces import InterfaceUnsupportedError
from pennylane.interfaces.jax import _raise_vector_valued_fwd

dtype = jnp.float64

@@ -62,8 +61,6 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
if max_diff > 1:
raise InterfaceUnsupportedError("The JAX interface only supports first order derivatives.")

_validate_tapes(tapes)

for tape in tapes:
# set the trainable parameters
params = tape.get_parameters(trainable_only=False)
@@ -92,31 +89,39 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_d
)


def _validate_tapes(tapes):
"""Validates that the input tapes are compatible with JAX support.
def _numeric_type_to_dtype(numeric_type):
"""Auxiliary function for converting from Python numeric types to JAX
dtypes based on the precision defined for the interface."""

single_precision = dtype is jnp.float32
if numeric_type is int:
return jnp.int32 if single_precision else jnp.int64

if numeric_type is float:
return jnp.float32 if single_precision else jnp.float64

# numeric_type is complex
return jnp.complex64 if single_precision else jnp.complex128

Raises:
ValueError: if tapes with non-scalar outputs were provided or a return
type other than variance and expectation value was used

def _extract_shape_dtype_structs(tapes, device):
"""Auxiliary function for defining the jax.ShapeDtypeStruct objects given
the tapes and the device.
The host_callback.call function expects jax.ShapeDtypeStruct objects to
describe the output of the function call.
"""
shape_dtypes = []

for t in tapes:
shape = t.shape(device)

if len(t.observables) != 1:
raise InterfaceUnsupportedError(
"The jittable JAX interface currently only supports quantum nodes with a single return type."
)
tape_dtype = _numeric_type_to_dtype(t.numeric_type)
shape_and_dtype = jax.ShapeDtypeStruct(tuple(shape), tape_dtype)

for o in t.observables:
return_type = o.return_type
if (
return_type is not Variance
and return_type is not Expectation
and return_type is not VnEntropy
and return_type is not MutualInfo
):
raise InterfaceUnsupportedError(
f"Only Variance and Expectation returns are supported for the jittable JAX interface, given {return_type}."
)
shape_dtypes.append(shape_and_dtype)

return shape_dtypes


def _execute(
@@ -128,8 +133,6 @@ def _execute(
gradient_kwargs=None,
_n=1,
): # pylint: disable=dangerous-default-value,unused-argument
# Only have scalar outputs
total_size = len(tapes)
total_params = np.sum([len(p) for p in params])

# Copy a given tape with operations and set parameters
@@ -147,8 +150,8 @@ def wrapper(p):
res, _ = execute_fn(new_tapes, **gradient_kwargs)
return res

shapes = [jax.ShapeDtypeStruct((1,), dtype) for _ in range(total_size)]
res = host_callback.call(wrapper, params, result_shape=shapes)
shape_dtype_structs = _extract_shape_dtype_structs(tapes, device)
res = host_callback.call(wrapper, params, result_shape=shape_dtype_structs)
return res

def wrapped_exec_fwd(params):
@@ -213,7 +216,10 @@ def jacs_wrapper(p):
jacs = gradient_fn(new_tapes, **gradient_kwargs)
return jacs

shapes = [jax.ShapeDtypeStruct((1, len(p)), dtype) for p in params]
shapes = [
jax.ShapeDtypeStruct((len(t.measurements), len(p)), dtype)
for t, p in zip(tapes, params)
]
jacs = host_callback.call(jacs_wrapper, params, result_shape=shapes)
vjps = [qml.gradients.compute_vjp(d, jac) for d, jac in zip(g, jacs)]
res = [[jnp.array(p) for p in v] for v in vjps]
@@ -232,10 +238,6 @@ def _execute_with_fwd(
gradient_kwargs=None,
_n=1,
): # pylint: disable=dangerous-default-value,unused-argument

# Only have scalar outputs
total_size = len(tapes)

@jax.custom_vjp
def wrapped_exec(params):
def wrapper(p):
@@ -251,12 +253,24 @@ def wrapper(p):
# On the forward execution return the jacobian too
return res, jacs

fwd_shapes = [jax.ShapeDtypeStruct((1,), dtype) for _ in range(total_size)]
jacobian_shape = [jax.ShapeDtypeStruct((1, len(p)), dtype) for p in params]
fwd_shape_dtype_struct = _extract_shape_dtype_structs(tapes, device)

jacobian_shape = [t.shape(device) + (len(p),) for t in tapes for p in params]
jac_dtypes = [_numeric_type_to_dtype(t.numeric_type) for t in tapes]

# Note: for qml.probs we'll first have a [1,dim] shape for the tape
# which is then reduced by the QNode
jacobian_shape = [
jax.ShapeDtypeStruct(tuple([shape]), dtype)
if isinstance(shape, int)
else jax.ShapeDtypeStruct(tuple(shape), dtype)
for shape, dtype in zip(jacobian_shape, jac_dtypes)
]

res, jacs = host_callback.call(
wrapper,
params,
result_shape=tuple([fwd_shapes, jacobian_shape]),
result_shape=tuple([fwd_shape_dtype_struct, jacobian_shape]),
)
return res, jacs

@@ -266,6 +280,8 @@ def wrapped_exec_fwd(params):

def wrapped_exec_bwd(params, g):

_raise_vector_valued_fwd(tapes)

# Use the jacobian that was computed on the forward pass
jacs, params = params

Loading

0 comments on commit 29557dd

Please sign in to comment.