Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vector-valued QNodes with JAX using host_callback.call #2034

Merged
merged 174 commits into from
Jun 10, 2022
Merged
Changes from 1 commit
Commits
Show all changes
174 commits
Select commit Hold shift + click to select a range
d423f36
getting shapes changed
antalszava Dec 15, 2021
ee1aa35
Merge branch 'master' into qnode_vector_val_jax
antalszava Dec 15, 2021
c09e946
Adjust tests
antalszava Dec 15, 2021
2902a73
format
antalszava Dec 15, 2021
00105b1
vector valued test case
antalszava Dec 15, 2021
af25518
copy the output_dim too
antalszava Dec 16, 2021
840d736
qml.density_matrix
antalszava Dec 16, 2021
475bdc2
var
antalszava Dec 16, 2021
cb3cd1c
sampling dim
antalszava Dec 16, 2021
2fb7c30
squeeze
antalszava Dec 16, 2021
ac8c10b
prep for shape addition
antalszava Dec 16, 2021
3ae1e19
Merge branch 'master' into qnode_vector_val_jax
antalszava Dec 16, 2021
23a3036
logic and tests first push
antalszava Dec 16, 2021
7cfd0d2
Merge branch 'master' into tape_output_shape_getter
antalszava Dec 17, 2021
0de59fb
Merge branch 'master' into tape_output_shape_getter
antalszava Jan 26, 2022
033cd9d
first attempt for shape
antalszava Jan 26, 2022
4ca2e82
Merge branch 'master' into tape_output_shape_getter
antalszava Jan 26, 2022
a53f07c
more single measure tests
antalszava Jan 26, 2022
727eaf9
TODOs
antalszava Jan 26, 2022
5271d59
tests
antalszava Jan 26, 2022
08508fb
format
antalszava Jan 26, 2022
b6c5dcb
polish comments
antalszava Jan 26, 2022
4c7c4f8
Merge branch 'master' into tape_output_shape_getter
antalszava Jan 26, 2022
7be68ac
correct call
antalszava Jan 26, 2022
7de64c0
correct call
antalszava Jan 26, 2022
4bbab4d
use kwargs in measure.py
antalszava Jan 26, 2022
6a36160
lint
antalszava Jan 26, 2022
9cc7986
output domain
antalszava Jan 27, 2022
697b913
Merge branch 'qnode_vector_val_jax' of github.com:PennyLaneAI/pennyla…
antalszava Jan 27, 2022
bbe509e
sampling
antalszava Jan 27, 2022
654caf6
Merge branch 'master' into tape_output_shape_getter
antalszava Jan 27, 2022
09da010
re-add obs as a potential positional argument
antalszava Jan 27, 2022
bb276c0
adjust the default; add more tests
antalszava Jan 27, 2022
d975d6c
docstrings; qml.sample() out domain case
antalszava Jan 27, 2022
c111b3b
qml.probs example with dummy device defining a cutoff
antalszava Jan 27, 2022
e333e95
error handling with tests
antalszava Jan 27, 2022
b50d84d
format
antalszava Jan 27, 2022
767b540
merge and resolve conflicts
antalszava Jan 27, 2022
432de0b
checking shape outputs
antalszava Jan 27, 2022
9e6eda9
coverage
antalszava Jan 27, 2022
8472ff1
formatting
antalszava Jan 27, 2022
e4eef69
changelog
antalszava Jan 27, 2022
5b037c1
qml.density_matrix case
antalszava Jan 27, 2022
ed758d9
comments; polishing
antalszava Jan 27, 2022
2903828
format
antalszava Jan 27, 2022
a230c67
more test; order
antalszava Jan 28, 2022
514412d
format
antalszava Jan 28, 2022
9c3dba6
Update pennylane/measure.py
antalszava Jan 28, 2022
8471238
Resolve; use TapeErro; revert kwargs in MP class
antalszava Mar 16, 2022
3ecd16a
format
antalszava Mar 16, 2022
82afd6b
docstring
antalszava Mar 17, 2022
8e2298d
resolve conflict
antalszava Mar 17, 2022
56b3442
change path
antalszava Mar 17, 2022
edd7872
Merge branch 'master' into tape_output_shape_getter
antalszava Mar 17, 2022
b9fd998
shape intro order change
antalszava Mar 17, 2022
b325ec4
Merge branch 'master' into tape_output_shape_getter
antalszava Mar 18, 2022
7fc481a
Merge branch 'master' into tape_output_shape_getter
antalszava Mar 18, 2022
5c2c8cb
correct previous resolution
antalszava Mar 18, 2022
5568d3b
Merge branch 'tape_output_shape_getter' of github.com:PennyLaneAI/pen…
antalszava Mar 18, 2022
f135069
resolve
antalszava Mar 18, 2022
17742a6
docstring
antalszava Mar 18, 2022
4c9b6bc
docstring
antalszava Mar 18, 2022
b44db58
Merge branch 'master' into qnode_vector_val_jax
antalszava Mar 18, 2022
8b83830
Merge branch 'tape_output_shape_getter' into qnode_vector_val_jax
antalszava Mar 18, 2022
0580a2d
QuantumTape
antalszava Mar 18, 2022
6833caa
no print
antalszava Mar 18, 2022
330659f
merge master
antalszava Mar 24, 2022
823ff59
merge master
antalszava Mar 24, 2022
119f4e0
Merge branch 'master' into tape_output_shape_getter
antalszava Mar 28, 2022
cd85dd1
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 11, 2022
488c126
Update pennylane/tape/tape.py
antalszava Apr 11, 2022
de2c278
linting
antalszava Apr 11, 2022
257fb95
Merge branch 'tape_output_shape_getter' of github.com:PennyLaneAI/pen…
antalszava Apr 11, 2022
f3b0cdc
filter warning ragged nested sequences
antalszava Apr 11, 2022
a48b25b
Update pennylane/tape/tape.py
antalszava Apr 11, 2022
e9cea33
Merge branch 'tape_output_shape_getter' into qnode_vector_val_jax
antalszava Apr 11, 2022
1b023b8
ignore different arguments linting error
antalszava Apr 11, 2022
cbdd88f
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 11, 2022
f49c08d
Merge branch 'tape_output_shape_getter' of github.com:PennyLaneAI/pen…
antalszava Apr 11, 2022
ae0983a
Merge branch 'fix_opt_linting' into tape_output_shape_getter
antalszava Apr 11, 2022
53ddfc9
Merge branch 'tape_output_shape_getter' into qnode_vector_val_jax
antalszava Apr 11, 2022
1bcbac6
no test raises multi-expval
antalszava Apr 11, 2022
7c581f0
more notes
antalszava Apr 11, 2022
bf4123b
rename num_systems
antalszava Apr 11, 2022
7843b93
updates
antalszava Apr 12, 2022
7dde068
rename to shape and result_type
antalszava Apr 12, 2022
6238f61
numeric_type
antalszava Apr 12, 2022
a17c42e
docstring
antalszava Apr 12, 2022
6a49107
linting issue fix
antalszava Apr 12, 2022
0725ba5
test updates
antalszava Apr 12, 2022
0901012
refactor
antalszava Apr 12, 2022
d4d01d5
docstring
antalszava Apr 12, 2022
dee9a21
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 12, 2022
57910e9
more tests
antalszava Apr 12, 2022
5bd6542
changelog
antalszava Apr 12, 2022
33bf00d
changelog
antalszava Apr 12, 2022
2cde89b
comment
antalszava Apr 12, 2022
3dc5175
test
antalszava Apr 12, 2022
2b7c92e
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 12, 2022
6d819d7
update docstring
antalszava Apr 12, 2022
3aec877
Merge branch 'tape_output_shape_getter' of github.com:PennyLaneAI/pen…
antalszava Apr 12, 2022
1d79490
tests
antalszava Apr 12, 2022
502afba
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 13, 2022
2393f68
swap qml.eigvals use to checking the type of observables
antalszava Apr 13, 2022
b128e24
Update pennylane/measurements.py
antalszava Apr 13, 2022
116621e
example
antalszava Apr 13, 2022
e314cf4
examples in the docstring
antalszava Apr 13, 2022
a5e69ff
Update pennylane/measurements.py
antalszava Apr 13, 2022
ec97966
Update pennylane/measurements.py
antalszava Apr 13, 2022
efa5489
Update pennylane/measurements.py
antalszava Apr 13, 2022
566819b
lru cache
antalszava Apr 13, 2022
5119804
docstring
antalszava Apr 13, 2022
5b287e0
Update pennylane/tape/tape.py
antalszava Apr 13, 2022
acd00b5
smaller shape doc for tape
antalszava Apr 13, 2022
7e4f99e
apply suggestions for the docstrings; have tape numeric_type as property
antalszava Apr 13, 2022
a43a9b7
docstring
antalszava Apr 13, 2022
7a91b7d
resolve
antalszava Apr 13, 2022
81f38b5
fix state and density_matrix for state vector
antalszava Apr 13, 2022
c1a32b5
state tests
antalszava Apr 13, 2022
bf5de4f
Merge branch 'master' into tape_output_shape_getter
antalszava Apr 13, 2022
d3b3dd1
changelog & test for uncovered DefaultQubitTF line
antalszava Apr 13, 2022
8e401fa
Merge branch 'tape_output_shape_getter' of github.com:PennyLaneAI/pen…
antalszava Apr 13, 2022
67a1782
Merge branch 'tape_output_shape_getter' into qnode_vector_val_jax
antalszava Apr 13, 2022
e2de3dc
renames, new locations
antalszava Apr 13, 2022
8a6d582
fwd shapes; docstring; example adjustments
antalszava Apr 13, 2022
40b2ff3
jax jit QNode test cases
antalszava Apr 13, 2022
2ec5937
resolve
antalszava Apr 13, 2022
4df8d1e
docstring
antalszava Apr 14, 2022
2580434
lint test_wires.py
antalszava Apr 14, 2022
b31286c
Merge branch 'master' into qnode_vector_val_jax
antalszava Apr 18, 2022
46ef0ef
Merge branch 'qnode_vector_val_jax' of github.com:PennyLaneAI/pennyla…
antalszava Apr 18, 2022
87f6b22
resolve [skip ci]
antalszava May 16, 2022
958fc33
no print; no error for now
antalszava May 16, 2022
af4a0d3
Merge branch 'master' into qnode_vector_val_jax
antalszava May 16, 2022
dcc5b9b
new shape and prints
antalszava May 17, 2022
a647815
Merge branch 'master' into qnode_vector_val_jax
antalszava May 17, 2022
f0acf11
[skip ci]
antalszava May 17, 2022
5f74159
typo correct
antalszava May 17, 2022
c323e14
double prec in test file to have no warnings
antalszava May 17, 2022
8d72ad1
err and tests, [skip ci]
antalszava May 17, 2022
0ee059b
Make numeric_type a property because we always want it to be defined …
antalszava May 19, 2022
a1bbe13
remove print sttment
antalszava May 19, 2022
329d95a
skip jacobian calc for probs differentiation [skip ci]
antalszava May 19, 2022
dbe787b
remove print statement and unused meas
antalszava May 20, 2022
15d7db0
no need for xfail
antalszava May 20, 2022
676788e
remove old scratch
antalszava May 20, 2022
7d81936
more
antalszava May 20, 2022
3df1370
more
antalszava May 25, 2022
d631d11
refactor shape logic
antalszava Jun 1, 2022
c7ecbf7
Fix the include pattern used for running black
antalszava Jun 1, 2022
1e5855d
Run black on pennylane
antalszava Jun 1, 2022
e19a409
Run black on tests
antalszava Jun 1, 2022
e28419a
Merge branch 'fix_black_include_pattern' into qnode_vector_val_jax
antalszava Jun 1, 2022
0d9958a
changes
antalszava Jun 1, 2022
844f719
Update error type; add tests for unrecognized return types
antalszava Jun 1, 2022
e1cf616
[skip ci]
antalszava Jun 2, 2022
3cef110
Merge branch 'master' into qnode_vector_val_jax
antalszava Jun 2, 2022
12f84b4
multi-tape unit test
antalszava Jun 2, 2022
11deb46
gradient scalar cost func using vv QNode (jitting)
antalszava Jun 2, 2022
7664bce
gradient scalar cost func using vv QNode (jitting)
antalszava Jun 2, 2022
64190dc
unit test shapes
antalszava Jun 2, 2022
edf2723
changelog
antalszava Jun 2, 2022
e16c885
docs update, changelog
antalszava Jun 2, 2022
e0dbd49
docs
antalszava Jun 2, 2022
ddbd728
changelog
antalszava Jun 2, 2022
5cea7db
Update doc/releases/changelog-dev.md
antalszava Jun 9, 2022
ad9dcdf
Update tests/interfaces/test_jax.py
antalszava Jun 9, 2022
1109c8c
type return
antalszava Jun 9, 2022
d4c9f46
note on shape
antalszava Jun 9, 2022
141f9bb
refactor forward pass shape and dtype struct building
antalszava Jun 9, 2022
550dfed
resolve
antalszava Jun 9, 2022
2c02d7e
correct changelog
antalszava Jun 9, 2022
1aa3247
resolve changelog; numeric type and shape updates for mutual info and…
antalszava Jun 10, 2022
cb485c3
no int case anymore: would return a single element tuple; comment is …
antalszava Jun 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
getting shapes changed
antalszava committed Dec 15, 2021
commit d423f3677b22e3b602aa7da2e634f875b481e119
31 changes: 27 additions & 4 deletions pennylane/interfaces/batch/jax.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@

import numpy as np
import pennylane as qml
from pennylane.operation import Variance, Expectation
from pennylane.operation import Variance, Expectation, State

dtype = jnp.float64

@@ -108,6 +108,29 @@ def _validate_tapes(tapes):
f"Only Variance and Expectation returns are supported for the JAX interface, given {return_type}."
)

def get_shapes_and_dtype(tapes, device):
dtype = jnp.float64

shapes = []
for t in tapes:
out_dim = t.output_dim
if out_dim == 0:
obs = t.observables[0]
if obs.return_type == State:
dtype = jnp.complex128
if obs.wires:
shapes.append(jax.ShapeDtypeStruct((1, obs.wires), dtype))
else:
shapes.append(jax.ShapeDtypeStruct((1, 2 ** len(device.wires)), dtype))

elif out_dim == 1:
shapes.append(jax.ShapeDtypeStruct((1,), dtype))
if out_dim > 1:
shapes.append(jax.ShapeDtypeStruct((1, out_dim), dtype))

return shapes, dtype



def _execute(
params,
@@ -119,7 +142,7 @@ def _execute(
_n=1,
): # pylint: disable=dangerous-default-value,unused-argument

_validate_tapes(tapes)
#_validate_tapes(tapes)

# Only have scalar outputs
total_size = len(tapes)
@@ -140,7 +163,7 @@ def wrapper(p):

return res

shapes = [jax.ShapeDtypeStruct((1,), dtype) for _ in range(total_size)]
shapes, _ = get_shapes_and_dtype(tapes, device)
res = host_callback.call(wrapper, params, result_shape=shapes)
return res

@@ -232,7 +255,7 @@ def _execute_with_fwd(
_n=1,
): # pylint: disable=dangerous-default-value,unused-argument

_validate_tapes(tapes)
#_validate_tapes(tapes)

# Only have scalar outputs
total_size = len(tapes)
42 changes: 42 additions & 0 deletions tests/interfaces/test_batch_jax.py
Original file line number Diff line number Diff line change
@@ -610,3 +610,45 @@ def cost_fn(a, p, device):
]
)
assert np.allclose(res, expected, atol=tol, rtol=0)

def test_independent_expval(self, execute_kwargs):
"""Tests computing an expectation value that is independent trainable
parameters."""
dev = qml.device("default.qubit", wires=2)
params = jnp.array([0.1, 0.2, 0.3])

def cost(a, cache):
with qml.tape.JacobianTape() as tape:
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
qml.RY(a[2], wires=0)
qml.expval(qml.PauliZ(1))

res = qml.interfaces.batch.execute(
[tape], dev, cache=cache, interface="jax", **execute_kwargs
)
return res[0][0]

res = jax.grad(cost)(params, cache=None)
assert res.shape == (1, 3)

def test_multiple_expvals(self, execute_kwargs):
"""Tests computing multiple expectation values in a tape."""
dev = qml.device("default.qubit", wires=2)
params = jnp.array([0.1, 0.2, 0.3])

def cost(a, cache):
with qml.tape.JacobianTape() as tape:
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
qml.RY(a[2], wires=0)
qml.expval(qml.PauliZ(0))
qml.expval(qml.PauliZ(1))

res = qml.interfaces.batch.execute(
[tape], dev, cache=cache, interface="jax", **execute_kwargs
)
return res[0]

res = jax.grad(cost)(params, cache=None)
assert res.shape == (2, 3)
89 changes: 60 additions & 29 deletions tests/interfaces/test_batch_jax_qnode.py
Original file line number Diff line number Diff line change
@@ -300,49 +300,80 @@ def circuit(a, p):
)
assert np.allclose(res, expected, atol=tol, rtol=0)

def test_multiple_outputs_raises(self, dev_name, diff_method, mode, tol):
"""Test executing a QNode that has multiple outputs raises an error."""
# def test_multiple_outputs_raises(self, dev_name, diff_method, mode, tol):
# """Test executing a QNode that has multiple outputs raises an error."""
# dev = qml.device(dev_name, wires=2)

# if diff_method == "backprop":
# pytest.skip("Test is not applicable for backprop")

# @qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode)
# def my_circuit(param):
# qml.RX(param, wires=0)
# qml.CNOT(wires=[0, 1])
# return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

# with pytest.raises(
# ValueError,
# match="JAX interface currently only supports quantum nodes with a single return type",
# ):
# my_circuit(1)

# @pytest.mark.parametrize("ret", [qml.probs(wires=0), qml.state()])
# def test_not_expval_or_var_raises(self, dev_name, diff_method, mode, ret, tol):
# """Test executing a QNode that has a return type other than expval or
# var raises an error."""
# dev = qml.device(dev_name, wires=2)

# if diff_method == "backprop":
# pytest.skip("Test is not applicable for backprop")

# if diff_method == "adjoint":
# pytest.skip("Adjoint does not support states")

# @qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode)
# def my_circuit(param):
# qml.RX(param, wires=0)
# qml.CNOT(wires=[0, 1])
# return qml.apply(ret)

# with pytest.raises(
# ValueError,
# match="Only Variance and Expectation returns are supported for the JAX interface",
# ):
# my_circuit(1)

def test_probs(self, dev_name, diff_method, mode, tol):
"""Test executing a QNode that has a return type qml.probs."""
dev = qml.device(dev_name, wires=2)

if diff_method == "backprop":
pytest.skip("Test is not applicable for backprop")
if diff_method == "adjoint":
pytest.skip("Adjoint does not support states")

@qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode)
def my_circuit(param):
qml.RX(param, wires=0)
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

with pytest.raises(
ValueError,
match="JAX interface currently only supports quantum nodes with a single return type",
):
my_circuit(1)

@pytest.mark.parametrize("ret", [qml.probs(wires=0), qml.state()])
def test_not_expval_or_var_raises(self, dev_name, diff_method, mode, ret, tol):
"""Test executing a QNode that has a return type other than expval or
var raises an error."""
dev = qml.device(dev_name, wires=2)
qml.Hadamard(0)
return qml.probs(wires=0)

if diff_method == "backprop":
pytest.skip("Test is not applicable for backprop")
expected = jnp.array([0.5, 0.5])

assert jnp.allclose(my_circuit(1), expected)

def test_state(self, dev_name, diff_method, mode, tol):
"""Test executing a QNode that has a return type qml.state."""
dev = qml.device(dev_name, wires=1)

if diff_method == "adjoint":
pytest.skip("Adjoint does not support states")

@qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode)
def my_circuit(param):
qml.RX(param, wires=0)
qml.CNOT(wires=[0, 1])
return qml.apply(ret)
qml.Hadamard(0)
return qml.state()

with pytest.raises(
ValueError,
match="Only Variance and Expectation returns are supported for the JAX interface",
):
my_circuit(1)
expected = 1/jnp.sqrt(2) * jnp.ones(2)

assert jnp.allclose(my_circuit(1), expected)

class TestShotsIntegration:
"""Test that the QNode correctly changes shot value, and