Skip to content

Commit

Permalink
Fix possibility of falsy zero logic for work_wires in templates (#6720
Browse files Browse the repository at this point in the history
)

**Context:**

Follow up to #6713. 

Previously, `work_wires=0` was incorrectly being mapped to an empty
list,
```python
>>> qml.Adder(1, x_wires=(), mod=1, work_wires=0)
TypeError: object of type 'int' has no len()
```

Now, we have it properly handled,
```python
>>> qml.Adder(1, x_wires=(), mod=1, work_wires=0)
Adder(wires=[0])
>>> qml.Adder(1, x_wires=(), mod=1, work_wires=None)
ValueError: Adder: wrong number of wires. At least one wire has to be given.
```

**Description of the Change:**

**Source Code**

Change the faulty logic of `work_wires = work_wires or ()` which will
fail if `work_wires=0`.

**Test Suite**

Added tests to templates that allow single work wires, ensuring
`work_wires=0` can be passed.

**Benefits:** Fixes possibility of falsy zero logic in `work_wires`.

**Possible Drawbacks:** None.

[[sc-80454](https://app.shortcut.com/xanaduai/story/80454)]

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
andrijapau and mudit2812 authored Dec 16, 2024
1 parent 4555833 commit 579e70d
Show file tree
Hide file tree
Showing 12 changed files with 34 additions and 28 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ same information.

* The `Wires` object throws a `TypeError` if `wires=None`.
[(#6713)](https://github.com/PennyLaneAI/pennylane/pull/6713)
[(#6720)](https://github.com/PennyLaneAI/pennylane/pull/6720)

* The `qml.Hermitian` class no longer checks that the provided matrix is hermitian.
The reason for this removal is to allow for faster execution and avoid incompatibilities with `jax.jit`.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/ops/qubit/state_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ class BasisState(StatePrepBase):
[0.+0.j 0.+0.j 0.+0.j 1.+0.j]
"""

def __init__(self, state, wires, id=None):
def __init__(self, state, wires: WiresLike, id=None):

wires = Wires(wires)
if isinstance(state, list):
state = qml.math.stack(state)

Expand All @@ -86,7 +87,6 @@ def __init__(self, state, wires, id=None):
bin = 2 ** math.arange(len(wires))[::-1]
state = qml.math.where((state & bin) > 0, 1, 0)

wires = Wires(wires)
shape = qml.math.shape(state)

if len(shape) != 1:
Expand Down
11 changes: 3 additions & 8 deletions pennylane/templates/subroutines/adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def __init__(
): # pylint: disable=too-many-arguments

x_wires = qml.wires.Wires(x_wires)
work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)

num_works_wires = len(work_wires)

Expand All @@ -125,15 +124,11 @@ def __init__(
f"with len(x_wires)={len(x_wires)} is {2 ** len(x_wires)}, but received {mod}."
)

all_wires = (
qml.wires.Wires(x_wires) + qml.wires.Wires(work_wires)
if work_wires
else qml.wires.Wires(x_wires)
)
all_wires = x_wires + work_wires

self.hyperparameters["k"] = k
self.hyperparameters["mod"] = mod
self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
self.hyperparameters["work_wires"] = work_wires
self.hyperparameters["x_wires"] = x_wires

super().__init__(wires=all_wires, id=id)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/templates/subroutines/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def __init__(
): # pylint: disable=too-many-arguments

output_wires = qml.wires.Wires(output_wires)
work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)

if len(work_wires) == 0:
raise ValueError("Work wires must be specified for ModExp")

Expand Down
9 changes: 4 additions & 5 deletions pennylane/templates/subroutines/multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def __init__(
): # pylint: disable=too-many-arguments

x_wires = qml.wires.Wires(x_wires)
work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)
if len(work_wires) == 0:
raise ValueError("Work wires must be specified for Multiplier")

Expand All @@ -148,9 +147,9 @@ def __init__(

self.hyperparameters["k"] = k
self.hyperparameters["mod"] = mod
self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
self.hyperparameters["x_wires"] = qml.wires.Wires(x_wires)
all_wires = qml.wires.Wires(x_wires) + qml.wires.Wires(work_wires)
self.hyperparameters["work_wires"] = work_wires
self.hyperparameters["x_wires"] = x_wires
all_wires = x_wires + work_wires
super().__init__(wires=all_wires, id=id)

@property
Expand Down
3 changes: 1 addition & 2 deletions pennylane/templates/subroutines/out_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def __init__(
x_wires = qml.wires.Wires(x_wires)
y_wires = qml.wires.Wires(y_wires)
output_wires = qml.wires.Wires(output_wires)
work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)

num_work_wires = len(work_wires)

Expand Down
3 changes: 1 addition & 2 deletions pennylane/templates/subroutines/out_multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ def __init__(
x_wires = qml.wires.Wires(x_wires)
y_wires = qml.wires.Wires(y_wires)
output_wires = qml.wires.Wires(output_wires)
work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)

num_work_wires = len(work_wires)

Expand Down
5 changes: 2 additions & 3 deletions pennylane/templates/subroutines/out_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def __init__(

registers_wires = [*input_registers, output_wires]

work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)
num_work_wires = len(work_wires)
if mod is None:
mod = 2 ** len(registers_wires[-1])
Expand Down Expand Up @@ -302,7 +301,7 @@ def __init__(

self.hyperparameters["polynomial_function"] = polynomial_function
self.hyperparameters["mod"] = mod
self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
self.hyperparameters["work_wires"] = work_wires

wires_vars = [len(w) for w in registers_wires[:-1]]

Expand Down
3 changes: 1 addition & 2 deletions pennylane/templates/subroutines/phase_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def __init__(
self, k, x_wires: WiresLike, mod=None, work_wire: WiresLike = (), id=None
): # pylint: disable=too-many-arguments

work_wire = work_wire or ()
work_wire = qml.wires.Wires(work_wire)
work_wire = qml.wires.Wires(() if work_wire is None else work_wire)
x_wires = qml.wires.Wires(x_wires)

num_work_wires = len(work_wire)
Expand Down
3 changes: 1 addition & 2 deletions pennylane/templates/subroutines/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ def __init__(
control_wires = qml.wires.Wires(control_wires)
target_wires = qml.wires.Wires(target_wires)

work_wires = work_wires or ()
work_wires = qml.wires.Wires(work_wires)
work_wires = qml.wires.Wires(() if work_wires is None else work_wires)

self.hyperparameters["bitstrings"] = tuple(bitstrings)
self.hyperparameters["control_wires"] = control_wires
Expand Down
10 changes: 10 additions & 0 deletions tests/templates/test_subroutines/test_phase_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def test_standard_validity_Phase_Adder():
qml.ops.functions.assert_valid(op)


def test_falsy_zero_as_work_wire():
"""Test that work wire is not treated as a falsy zero."""
k = 6
mod = 11
x_wires = [1, 2, 3, 4]
work_wire = 0
op = qml.PhaseAdder(k, x_wires=x_wires, mod=mod, work_wire=work_wire)
qml.ops.functions.assert_valid(op)


def test_add_k_fourier():
"""Test the private _add_k_fourier function."""

Expand Down
6 changes: 6 additions & 0 deletions tests/templates/test_subroutines/test_qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def test_assert_valid_qrom():
qml.ops.functions.assert_valid(op)


def test_falsy_zero_as_work_wire():
"""Test that work wire is not treated as a falsy zero."""
op = qml.QROM(["1", "0", "0", "1"], control_wires=[1, 2], target_wires=[3], work_wires=0)
qml.ops.functions.assert_valid(op)


class TestQROM:
"""Test the qml.QROM template."""

Expand Down

0 comments on commit 579e70d

Please sign in to comment.