Skip to content

Commit

Permalink
Modify KaliskiModInverse to support zero (#1486)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri authored Nov 5, 2024
1 parent 58e6b71 commit 0c931f5
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 34 deletions.
161 changes: 128 additions & 33 deletions qualtran/bloqs/mod_arithmetic/mod_division.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np
import sympy
from attrs import evolve, frozen
from attrs import evolve, field, frozen

from qualtran import (
Bloq,
Expand Down Expand Up @@ -65,16 +65,23 @@ def signature(self) -> 'Signature':
Register('v', QMontgomeryUInt(self.bitsize)),
Register('m', QBit()),
Register('f', QBit()),
Register('is_terminal', QBit()),
]
)

def on_classical_vals(self, v: int, m: int, f: int) -> Dict[str, 'ClassicalValT']:
def on_classical_vals(
self, v: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
print('here')
assert False
m ^= f & (v == 0)
assert is_terminal == 0
is_terminal ^= m
f ^= m
return {'v': v, 'm': m, 'f': f}
return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_composite_bloq(
self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet
self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet, is_terminal: Soquet
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}')
Expand All @@ -89,7 +96,8 @@ def build_composite_bloq(
f = ctrls[-1]
v = bb.join(v_arr)
m, f = bb.add(CNOT(), ctrl=m, target=f)
return {'v': v, 'm': m, 'f': f}
m, is_terminal = bb.add(CNOT(), ctrl=m, target=is_terminal)
return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.bitsize):
Expand Down Expand Up @@ -408,16 +416,27 @@ def signature(self) -> 'Signature':
Register('s', QMontgomeryUInt(self.bitsize)),
Register('m', QBit()),
Register('f', QBit()),
Register('is_terminal', QBit()),
]
)

def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet
self,
bb: 'BloqBuilder',
u: Soquet,
v: Soquet,
r: Soquet,
s: Soquet,
m: Soquet,
f: Soquet,
is_terminal: Soquet,
) -> Dict[str, 'SoquetT']:
a = bb.allocate(1)
b = bb.allocate(1)

v, m, f = bb.add(_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f)
v, m, f, is_terminal = bb.add(
_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f, is_terminal=is_terminal
)
u, v, b, a, m, f = bb.add(
_KaliskiIterationStep2(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f
)
Expand All @@ -434,7 +453,7 @@ def build_composite_bloq(

bb.free(a)
bb.free(b)
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
Expand All @@ -447,7 +466,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
}

def on_classical_vals(
self, u: int, v: int, r: int, s: int, m: int, f: int
self, u: int, v: int, r: int, s: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
"""This is the Kaliski algorithm as described in Fig7 of https://arxiv.org/pdf/2001.09580.
Expand All @@ -456,6 +475,7 @@ def on_classical_vals(
of `f` and `m`.
"""
assert m == 0
is_terminal = f == 1 and v == 0
if f == 0:
# When `f = 0` this means that the algorithm is nearly over and that we just need to
# double the value of `r`.
Expand Down Expand Up @@ -484,7 +504,7 @@ def on_classical_vals(
if swap:
u, v = v, u
r, s = s, r
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal}


@frozen
Expand All @@ -504,6 +524,7 @@ def signature(self) -> 'Signature':
Register('s', QMontgomeryUInt(self.bitsize)),
Register('m', QAny(2 * self.bitsize)),
Register('f', QBit()),
Register('terminal_condition', QAny(2 * self.bitsize)),
]
)

Expand All @@ -512,17 +533,33 @@ def _kaliski_iteration(self):
return _KaliskiIteration(self.bitsize, self.mod)

def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet
self,
bb: 'BloqBuilder',
u: Soquet,
v: Soquet,
r: Soquet,
s: Soquet,
m: Soquet,
f: Soquet,
terminal_condition: Soquet,
) -> Dict[str, 'SoquetT']:
f = bb.add(XGate(), q=f)
u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u)
s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s)

m_arr = bb.split(m)
terminal_condition_arr = bb.split(terminal_condition)

for i in range(2 * self.bitsize):
u, v, r, s, m_arr[i], f = bb.add(
self._kaliski_iteration, u=u, v=v, r=r, s=s, m=m_arr[i], f=f
u, v, r, s, m_arr[i], f, terminal_condition_arr[i] = bb.add(
self._kaliski_iteration,
u=u,
v=v,
r=r,
s=s,
m=m_arr[i],
f=f,
is_terminal=terminal_condition_arr[i],
)

r = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=r)
Expand All @@ -531,8 +568,43 @@ def build_composite_bloq(
u = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=u)
s = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=s)

# This is an extra step not present in the original Kaliski algorithm in order to
# handle the case of x=0. The invariant of the Kaliski algorithm is that that end of the
# algorithm u=1, s=0, r=mod inverse. This happens for all cases where the modular inverse
# exists (i.e. gcd(x, mod) = 1).
# The case where the input is zero is important. Although mathematically the inverse
# doesn't exist. For the bloq to be unitary it needs to map zero to itself.
# When the input is zero, the terminal values of the registers are r=mod, u=v=mod^1=mod-1
# (assuming odd modulus).
# So we clean those registers conditioned on the first terminal qubit which is set
# if and only if the input is zero.
terminal_condition_arr[0], r = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(),
ctrl=terminal_condition_arr[0],
x=r,
)
terminal_condition_arr[0], u = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(),
ctrl=terminal_condition_arr[0],
x=u,
)
terminal_condition_arr[0], s = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(),
ctrl=terminal_condition_arr[0],
x=s,
)

m = bb.join(m_arr)
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
terminal_condition = bb.join(terminal_condition_arr)
return {
'u': u,
'v': v,
'r': r,
's': s,
'm': m,
'f': f,
'terminal_condition': terminal_condition,
}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
Expand All @@ -542,6 +614,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
XGate(): 1,
XorK(QMontgomeryUInt(self.bitsize), self.mod): 2,
XorK(QMontgomeryUInt(self.bitsize), 1): 2,
XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(): 1,
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(): 2,
}


Expand Down Expand Up @@ -575,7 +649,7 @@ class KaliskiModInverse(Bloq):
"""

bitsize: 'SymbolicInt'
mod: 'SymbolicInt'
mod: 'SymbolicInt' = field(validator=lambda _, __, v: is_symbolic(v) or v % 2 == 1)
uncompute: bool = False

@cached_property
Expand All @@ -584,22 +658,25 @@ def signature(self) -> 'Signature':
return Signature(
[
Register('x', QMontgomeryUInt(self.bitsize)),
Register('m', QAny(2 * self.bitsize), side=side),
Register('junk', QAny(4 * self.bitsize), side=side),
]
)

def build_composite_bloq(
self, bb: 'BloqBuilder', x: Soquet, m: Optional[Soquet] = None, f: Optional[Soquet] = None
self, bb: 'BloqBuilder', x: Soquet, junk: Optional[Soquet] = None
) -> Dict[str, 'SoquetT']:
u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
s = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
f = bb.allocate(1)

if self.uncompute:
assert m is not None
u, x, r, s, m, f = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
assert junk is not None
junk_arr = bb.split(junk)
m = bb.join(junk_arr[: 2 * self.bitsize])
terminal_condition = bb.join(junk_arr[2 * self.bitsize :])
u, x, r, s, m, f, terminal_condition = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod).adjoint(),
u=u,
Expand All @@ -608,47 +685,65 @@ def build_composite_bloq(
s=s,
m=m,
f=f,
terminal_condition=terminal_condition,
),
)
bb.free(u)
bb.free(r)
bb.free(s)
bb.free(m)
bb.free(f)
bb.free(terminal_condition)
return {'x': x}

m = bb.allocate(2 * self.bitsize)
u, v, x, s, m, f = bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod), u=u, v=x, r=r, s=s, m=m, f=f
terminal_condition = bb.allocate(2 * self.bitsize)
u, v, x, s, m, f, terminal_condition = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod),
u=u,
v=x,
r=r,
s=s,
m=m,
f=f,
terminal_condition=terminal_condition,
),
)

assert isinstance(u, Soquet)
assert isinstance(v, Soquet)
assert isinstance(s, Soquet)
assert isinstance(f, Soquet)
bb.free(u)
bb.free(v)
bb.free(s)
bb.free(f)
return {'x': x, 'm': m}
junk = bb.join(np.concatenate([bb.split(m), bb.split(terminal_condition)]))
return {'x': x, 'junk': junk}

def adjoint(self) -> 'KaliskiModInverse':
return evolve(self, uncompute=not self.uncompute)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return _KaliskiModInverseImpl(self.bitsize, self.mod).build_call_graph(ssa)

def on_classical_vals(self, x: int, m: int = 0) -> Dict[str, 'ClassicalValT']:
u, v, r, s, f = int(self.mod), x, 0, 1, 1
def on_classical_vals(self, x: int, junk: int = 0) -> Dict[str, 'ClassicalValT']:
mod = int(self.mod)
u, v, r, s, f = mod, x, 0, 1, 1
terminal_condition = m = 0
iteration = _KaliskiModInverseImpl(self.bitsize, self.mod)._kaliski_iteration
for _ in range(2 * int(self.bitsize)):
u, v, r, s, m_i, f = iteration.call_classically(u=u, v=v, r=r, s=s, m=0, f=f)
u, v, r, s, m_i, f, is_terminal = iteration.call_classically(
u=u, v=v, r=r, s=s, m=0, f=f, is_terminal=0
)
m = (m << 1) | m_i
assert u == 1
assert s == self.mod
terminal_condition = (terminal_condition << 1) | is_terminal
assert u == 1 or (x == 0 and u == mod)
assert s == self.mod or (x == 0 and s == 1)
assert f == 0
assert v == 0
return {'x': self.mod - r, 'm': m}
return {
'x': (self.mod - r) if r else 0,
'junk': m * 2 ** (2 * self.bitsize) + terminal_condition,
}


@bloq_example
Expand Down
15 changes: 14 additions & 1 deletion qualtran/bloqs/mod_arithmetic/mod_division_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,24 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
continue
x_montgomery = dtype.uint_to_montgomery(x, mod)
res = blq.call_classically(x=x_montgomery)
print(x, x_montgomery)
assert res == cblq.call_classically(x=x_montgomery)
assert len(res) == 2
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
assert dtype.montgomery_product(int(res[0]), x_montgomery, mod) == R
assert blq.adjoint().call_classically(x=res[0], m=res[1]) == (x_montgomery,)
assert blq.adjoint().call_classically(x=res[0], junk=res[1]) == (x_montgomery,)


@pytest.mark.parametrize('bitsize', [5, 6])
@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15])
def test_kaliski_mod_inverse_classical_action_zero(bitsize, mod):
blq = KaliskiModInverse(bitsize, mod)
cblq = blq.decompose_bloq()
# When x = 0 the terminal condition is achieved at the first iteration, this corresponds to
# m_0 = is_terminal_0 = 1 and all other bits = 0.
junk = 2 ** (4 * bitsize - 1) + 2 ** (2 * bitsize - 1)
assert blq.call_classically(x=0) == cblq.call_classically(x=0) == (0, junk)
assert blq.adjoint().call_classically(x=0, junk=junk) == (0,)


@pytest.mark.parametrize('bitsize', [5, 6])
Expand Down

0 comments on commit 0c931f5

Please sign in to comment.