Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add qubit discard/measure methods #580

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions guppylang/prelude/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,84 +26,128 @@ def __new__() -> "qubit":
reset(q)
return q

@guppy
@no_type_check
def measure(self: "qubit" @ owned) -> bool:
return measure(self)

@guppy
@no_type_check
def measure_return(self: "qubit") -> bool:
return measure_return(self)

@guppy
@no_type_check
def measure_reset(self: "qubit") -> bool:
"""Projective measure and reset without discarding the qubit."""
res = self.measure_return()
if res:
x(self)
return res

@guppy
@no_type_check
def discard(self: "qubit" @ owned) -> None:
discard(self)


@guppy.hugr_op(quantum_op("H"))
@no_type_check
def h(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CZ"))
@no_type_check
def cz(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CY"))
@no_type_check
def cy(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CX"))
@no_type_check
def cx(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("T"))
@no_type_check
def t(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("S"))
@no_type_check
def s(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("X"))
@no_type_check
def x(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Y"))
@no_type_check
def y(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Z"))
@no_type_check
def z(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Tdg"))
@no_type_check
def tdg(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Sdg"))
@no_type_check
def sdg(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("ZZMax", ext=HSERIES_EXTENSION))
@no_type_check
def zz_max(q1: qubit, q2: qubit) -> None: ...


@guppy.custom(RotationCompiler("Rz"))
@no_type_check
def rz(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("Rx"))
@no_type_check
def rx(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("Ry"))
@no_type_check
def ry(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("CRz"))
@no_type_check
def crz(control: qubit, target: qubit, angle: angle) -> None: ...


@guppy.hugr_op(quantum_op("Toffoli"))
@no_type_check
def toffoli(control1: qubit, control2: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("QAlloc"))
@no_type_check
def dirty_qubit() -> qubit: ...


@guppy.custom(MeasureReturnCompiler())
@no_type_check
def measure_return(q: qubit) -> bool: ...


@guppy.hugr_op(quantum_op("QFree"))
@no_type_check
def discard(q: qubit @ owned) -> None: ...


Expand Down Expand Up @@ -131,6 +175,7 @@ def zz_phase(q1: qubit, q2: qubit, angle: angle) -> None:


@guppy.hugr_op(quantum_op("Reset"))
@no_type_check
def reset(q: qubit) -> None: ...


Expand All @@ -140,6 +185,7 @@ def reset(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("PhasedX", ext=HSERIES_EXTENSION))
@no_type_check
def _phased_x(q: qubit, angle1: float, angle2: float) -> None:
"""PhasedX operation from the hseries extension.

Expand All @@ -149,6 +195,7 @@ def _phased_x(q: qubit, angle1: float, angle2: float) -> None:


@guppy.hugr_op(quantum_op("ZZPhase", ext=HSERIES_EXTENSION))
@no_type_check
def _zz_phase(q1: qubit, q2: qubit, angle: float) -> None:
"""ZZPhase operation from the hseries extension.

Expand Down
59 changes: 41 additions & 18 deletions tests/integration/test_inout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_basic(validate):
def foo(q: qubit) -> None: ...

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it appears my formatter has been at all the decorators. I can undo this if necessary

foo(q)
return q

Expand All @@ -29,10 +29,10 @@ def test_mixed(validate):
module.load_all(quantum)

@guppy.declare(module)
def foo(q1: qubit, q2: qubit @owned) -> qubit: ...
def foo(q1: qubit, q2: qubit @ owned) -> qubit: ...

@guppy(module)
def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]:
def test(q1: qubit @ owned, q2: qubit @ owned) -> tuple[qubit, qubit]:
q2 = foo(q1, q2)
return q1, q2

Expand All @@ -47,7 +47,7 @@ def test_local(validate):
def foo(q: qubit) -> None: ...

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
f = foo
f(q)
return q
Expand All @@ -63,7 +63,7 @@ def test_nested_calls(validate):
def foo(x: int, q: qubit) -> int: ...

@guppy(module)
def test(q: qubit @owned) -> tuple[int, qubit]:
def test(q: qubit @ owned) -> tuple[int, qubit]:
# This is legal since function arguments and tuples are evaluated left to right
return foo(foo(foo(0, q), q), q), q

Expand All @@ -86,13 +86,13 @@ def foo(q1: qubit, q2: qubit) -> None: ...
def bar(a: MyStruct) -> None: ...

@guppy(module)
def test1(a: MyStruct @owned) -> MyStruct:
def test1(a: MyStruct @ owned) -> MyStruct:
foo(a.q1, a.q2)
bar(a)
return a

@guppy(module)
def test2(a: MyStruct @owned) -> MyStruct:
def test2(a: MyStruct @ owned) -> MyStruct:
bar(a)
foo(a.q1, a.q2)
bar(a)
Expand All @@ -112,7 +112,7 @@ def foo(q: qubit) -> None: ...
def bar(q: qubit) -> bool: ...

@guppy(module)
def test(q1: qubit @owned, q2: qubit @owned, n: int) -> tuple[qubit, qubit]:
def test(q1: qubit @ owned, q2: qubit @ owned, n: int) -> tuple[qubit, qubit]:
i = 0
while i < n:
foo(q1)
Expand Down Expand Up @@ -162,13 +162,15 @@ class C:
def foo(a: A, x: int) -> None: ...

@guppy.declare(module)
def bar(y: float, b: B, c: C @owned) -> C: ...
def bar(y: float, b: B, c: C @ owned) -> C: ...

@guppy.declare(module)
def baz(c: C) -> None: ...

@guppy(module)
def test(a: A @owned, b: B @owned, c1: C @owned, c2: C @owned, x: bool) -> tuple[A, B, C, C]:
def test(
a: A @ owned, b: B @ owned, c1: C @ owned, c2: C @ owned, x: bool
) -> tuple[A, B, C, C]:
c1 = (foo, bar, baz)(a, b.x, c1.x, b, c1, c2)
if x:
c1 = ((foo, bar), baz)(a, b.x, c1.x, b, c1, c2)
Expand All @@ -191,7 +193,7 @@ def foo(q: qubit) -> None:
h(q)

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
foo(q)
foo(q)
return q
Expand All @@ -208,7 +210,7 @@ def test(q: qubit) -> None:
pass

@guppy(module)
def main(q: qubit @owned) -> qubit:
def main(q: qubit @ owned) -> qubit:
test(q)
return q

Expand All @@ -224,7 +226,7 @@ def foo(q: qubit) -> None: ...

@guppy(module)
def test(
b: int, c: qubit, d: float, a: tuple[qubit, qubit], e: qubit @owned
b: int, c: qubit, d: float, a: tuple[qubit, qubit], e: qubit @ owned
) -> tuple[qubit, float]:
foo(c)
return e, b + d
Expand All @@ -241,7 +243,7 @@ class MyStruct:
q: qubit

@guppy.declare(module)
def use(q: qubit @owned) -> None: ...
def use(q: qubit @ owned) -> None: ...

@guppy(module)
def foo(s: MyStruct) -> None:
Expand All @@ -257,7 +259,7 @@ def swap(s: MyStruct, t: MyStruct) -> None:
s.q, t.q = t.q, s.q

@guppy(module)
def main(s: MyStruct @owned, t: MyStruct @owned) -> MyStruct:
def main(s: MyStruct @ owned, t: MyStruct @ owned) -> MyStruct:
foo(s)
swap(s, t)
bar(t)
Expand All @@ -276,10 +278,12 @@ class MyStruct:
q: qubit

@guppy.declare(module)
def use(q: qubit @owned) -> None: ...
def use(q: qubit @ owned) -> None: ...

@guppy(module)
def test(s: MyStruct, b: bool, n: int, q1: qubit @owned, q2: qubit @owned) -> None:
def test(
s: MyStruct, b: bool, n: int, q1: qubit @ owned, q2: qubit @ owned
) -> None:
use(s.q)
if b:
s.q = q1
Expand All @@ -299,7 +303,7 @@ def test(s: MyStruct, b: bool, n: int, q1: qubit @owned, q2: qubit @owned) -> No
return

@guppy(module)
def main(s: MyStruct @owned) -> MyStruct:
def main(s: MyStruct @ owned) -> MyStruct:
test(s, False, 5, qubit(), qubit())
return s

Expand All @@ -325,6 +329,7 @@ def bar(self: "MyStruct", b: bool) -> None:

validate(module.compile())


def test_subtype(validate):
module = GuppyModule("test")
module.load_all(quantum)
Expand All @@ -340,6 +345,7 @@ def main() -> qubit:

validate(module.compile())


def test_shadow_check(validate):
module = GuppyModule("test")

Expand All @@ -354,3 +360,20 @@ def main(i: qubit) -> None:
foo(i)

validate(module.compile())


def test_self_qubit(validate):
module = GuppyModule("test")
module.load(qubit)

@guppy(module)
def test() -> bool:
q0 = qubit()

result = q0.measure_reset()
q0.measure_return()
q0.measure()
qubit().discard()
return result

validate(module.compile())
Loading