Skip to content

Commit

Permalink
Add Decomposition of ECWindowAddR (#1477)
Browse files Browse the repository at this point in the history
* ECWindowAddR complete

* fix test cases

* Fix mypy errors

* Address comments

* Clear outputs of ecc notebook

* slight speedup on test by using very small bitsize

* keep sympy symbolic

* Fix nit

---------

Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
  • Loading branch information
fpapa250 and mpharrigan authored Nov 4, 2024
1 parent b65ba6d commit 580ae65
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 59 deletions.
147 changes: 131 additions & 16 deletions qualtran/bloqs/factoring/ecc/ec_add_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@
from functools import cached_property
from typing import Dict, Optional, Tuple, Union

import numpy as np
import sympy
from attrs import frozen

from qualtran import Bloq, bloq_example, BloqDocSpec, QBit, QUInt, Register, Signature
from qualtran import (
Bloq,
bloq_example,
BloqBuilder,
BloqDocSpec,
QBit,
QMontgomeryUInt,
QUInt,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.data_loading import QROAMClean
from qualtran.drawing import Circle, Text, TextBox, WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import is_symbolic, Shaped

from .ec_add import ECAdd
from .ec_point import ECPoint


Expand Down Expand Up @@ -113,33 +130,134 @@ class ECWindowAddR(Bloq):
Args:
n: The bitsize of the two registers storing the elliptic curve point
window_size: The number of bits in the window.
R: The elliptic curve point to add.
R: The elliptic curve point to add (NOT in montgomery form).
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
Registers:
ctrl: `window_size` control bits.
x: The x component of the input elliptic curve point of bitsize `n`.
y: The y component of the input elliptic curve point of bitsize `n`.
x: The x component of the input elliptic curve point of bitsize `n` in montgomery form.
y: The y component of the input elliptic curve point of bitsize `n` in montgomery form.
References:
[How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585).
Litinski. 2013. Section 1, eq. (3) and (4).
"""

n: int
window_size: int
R: ECPoint
add_window_size: int
mul_window_size: int = 1

@cached_property
def signature(self) -> 'Signature':
return Signature(
[
Register('ctrl', QBit(), shape=(self.window_size,)),
Register('ctrl', QBit(), shape=(self.add_window_size,)),
Register('x', QUInt(self.n)),
Register('y', QUInt(self.n)),
]
)

@cached_property
def qrom(self) -> QROAMClean:
if is_symbolic(self.n) or is_symbolic(self.add_window_size):
log_block_sizes = None
if is_symbolic(self.n) and not is_symbolic(self.add_window_size):
# We assume that bitsize is much larger than window_size
log_block_sizes = (0,)
return QROAMClean(
[
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
Shaped((2**self.add_window_size,)),
],
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
log_block_sizes=log_block_sizes,
)

cR = self.R
data_a, data_b, data_lam = [0], [0], [0]
mon_int = QMontgomeryUInt(self.n)
for _ in range(1, 2**self.add_window_size):
data_a.append(mon_int.uint_to_montgomery(int(cR.x), int(self.R.mod)))
data_b.append(mon_int.uint_to_montgomery(int(cR.y), int(self.R.mod)))
lam_num = (3 * cR.x**2 + cR.curve_a) % cR.mod
lam_denom = (2 * cR.y) % cR.mod
if lam_denom != 0:
lam = (lam_num * pow(lam_denom, -1, mod=cR.mod)) % cR.mod
else:
lam = 0
data_lam.append(mon_int.uint_to_montgomery(int(lam), int(self.R.mod)))
cR = cR + self.R

return QROAMClean(
[data_a, data_b, data_lam],
selection_bitsizes=(self.add_window_size,),
target_bitsizes=(self.n, self.n, self.n),
)

def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'SoquetT', x: 'Soquet', y: 'Soquet'
) -> Dict[str, 'SoquetT']:
ctrl = bb.join(np.array(ctrl))

ctrl, a, b, lam_r, *junk = bb.add(self.qrom, selection=ctrl)

a, b, x, y, lam_r = bb.add(
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
ECAdd(n=self.n, mod=int(self.R.mod), window_size=self.mul_window_size),
a=a,
b=b,
x=x,
y=y,
lam_r=lam_r,
)

if junk:
assert len(junk) == 3
ctrl = bb.add(
self.qrom.adjoint(),
selection=ctrl,
target0_=a,
target1_=b,
target2_=lam_r,
junk_target0_=junk[0],
junk_target1_=junk[1],
junk_target2_=junk[2],
)
else:
ctrl = bb.add(
self.qrom.adjoint(), selection=ctrl, target0_=a, target1_=b, target2_=lam_r
)

return {'ctrl': bb.split(ctrl), 'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
self.qrom: 1,
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
ECAdd(self.n, int(self.R.mod), self.mul_window_size): 1,
self.qrom.adjoint(): 1,
}

def on_classical_vals(self, ctrl, x, y) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
A = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(int(x), int(self.R.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(int(y), int(self.R.mod)),
mod=self.R.mod,
curve_a=self.R.curve_a,
)
ctrls = QUInt(self.n).from_bits(ctrl)
result: ECPoint = A + (ctrls * self.R)
return {
'ctrl': ctrl,
'x': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.x), int(self.R.mod)),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.y), int(self.R.mod)),
}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
Expand All @@ -153,16 +271,13 @@ def wire_symbol(
return TextBox(f'$+{self.R.y}$')
raise ValueError(f'Unrecognized register name {reg.name}')

def __str__(self):
return f'ECWindowAddR({self.n=})'


@bloq_example
def _ec_window_add() -> ECWindowAddR:
n, p = sympy.symbols('n p')
Rx, Ry = sympy.symbols('Rx Ry')
ec_window_add = ECWindowAddR(n=n, window_size=3, R=ECPoint(Rx, Ry, mod=p))
return ec_window_add
def _ec_window_add_r_small() -> ECWindowAddR:
n = 16
P = ECPoint(2, 2, mod=7, curve_a=3)
ec_window_add_r_small = ECWindowAddR(n=n, R=P, add_window_size=4)
return ec_window_add_r_small


_EC_WINDOW_ADD_BLOQ_DOC = BloqDocSpec(bloq_cls=ECWindowAddR, examples=[_ec_window_add])
_EC_WINDOW_ADD_BLOQ_DOC = BloqDocSpec(bloq_cls=ECWindowAddR, examples=[_ec_window_add_r_small])
77 changes: 70 additions & 7 deletions qualtran/bloqs/factoring/ecc/ec_add_r_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from qualtran.bloqs.factoring.ecc.ec_add_r import _ec_add_r, _ec_add_r_small, _ec_window_add
import numpy as np
import pytest

import qualtran.testing as qlt_testing
from qualtran import QMontgomeryUInt, QUInt
from qualtran.bloqs.factoring.ecc.ec_add_r import (
_ec_add_r,
_ec_add_r_small,
_ec_window_add_r_small,
ECWindowAddR,
)
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join

def test_ec_add_r(bloq_autotester):
bloq_autotester(_ec_add_r)
from .ec_add_r import ECWindowAddR
from .ec_point import ECPoint


def test_ec_add_r_small(bloq_autotester):
bloq_autotester(_ec_add_r_small)
@pytest.mark.parametrize('bloq', [_ec_add_r, _ec_add_r_small, _ec_window_add_r_small])
def test_ec_add_r(bloq_autotester, bloq):
bloq_autotester(bloq)


def test_ec_window_add(bloq_autotester):
bloq_autotester(_ec_window_add)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize(
['n', 'window_size'],
[
(n, window_size)
for n in range(5, 8)
for window_size in range(1, n + 1)
if n % window_size == 0
],
)
def test_ec_window_add_r_bloq_counts(n, window_size, a, b):
p = 17
R = ECPoint(a, b, mod=p)
bloq = ECWindowAddR(n=n, R=R, add_window_size=window_size)
qlt_testing.assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(4, 5) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5])
def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
np.testing.assert_array_equal(ret1_i, ret2[i])


@pytest.mark.slow
@pytest.mark.parametrize(
['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0]
)
@pytest.mark.parametrize('a,b', [(15, 13), (0, 0)])
@pytest.mark.parametrize('x,y', [(15, 13), (5, 8)])
@pytest.mark.parametrize('ctrl', [0, 1, 5, 8])
def test_ec_window_add_r_classical_slow(n, m, ctrl, x, y, a, b):
p = 17
R = ECPoint(a, b, mod=p)
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
ret2 = bloq.decompose_bloq().call_classically(ctrl=ctrl, x=x, y=y)
for i, ret1_i in enumerate(ret1):
np.testing.assert_array_equal(ret1_i, ret2[i])
51 changes: 44 additions & 7 deletions qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from functools import cached_property
from typing import Dict
from typing import Dict, Union

import numpy as np
import sympy
from attrs import frozen

Expand All @@ -34,7 +36,7 @@
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator

from .._factoring_shims import MeasureQFT
from .ec_add_r import ECAddR
from .ec_add_r import ECAddR, ECWindowAddR
from .ec_point import ECPoint


Expand All @@ -45,38 +47,73 @@ class ECPhaseEstimateR(Bloq):
This is used as a subroutine in `FindECCPrivateKey`. First, we phase-estimate the
addition of the base point $P$, then of the public key $Q$.
When the ellptic curve point addition window size is 1 we use the ECAddR bloq which has it's
own bespoke circuit; when it is greater than 1 we use the windowed circuit which uses
pre-computed classical additions loaded into the circuit.
Args:
n: The bitsize of the elliptic curve points' x and y registers.
point: The elliptic curve point to phase estimate against.
add_window_size: The number of bits in the ECAdd window.
mul_window_size: The number of bits in the modular multiplication window.
"""

n: int
point: ECPoint
add_window_size: int = 1
mul_window_size: int = 1

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))])

@property
def ec_add(self) -> Union[functools.partial[ECAddR], functools.partial[ECWindowAddR]]:
if self.add_window_size == 1:
return functools.partial(ECAddR, n=self.n)
return functools.partial(
ECWindowAddR,
n=self.n,
add_window_size=self.add_window_size,
mul_window_size=self.mul_window_size,
)

@property
def num_windows(self) -> int:
return self.n // self.add_window_size

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")
ctrl = [bb.add(PlusState()) for _ in range(self.n)]
for i in range(self.n):
ctrl[i], x, y = bb.add(ECAddR(n=self.n, R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)

if self.add_window_size == 1:
for i in range(self.n):
ctrl[i], x, y = bb.add(self.ec_add(R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)
else:
ctrls = np.split(np.array(ctrl), self.num_windows)
for i in range(self.num_windows):
ctrls[i], x, y = bb.add(
self.ec_add(R=2 ** (self.add_window_size * i) * self.point),
ctrl=ctrls[i],
x=x,
y=y,
)
ctrl = np.concatenate(ctrls, axis=None)

bb.add(MeasureQFT(n=self.n), x=ctrl)
return {'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {ECAddR(n=self.n, R=self.point): self.n, MeasureQFT(n=self.n): 1}
return {self.ec_add(R=self.point): self.num_windows, MeasureQFT(n=self.n): 1}

def __str__(self) -> str:
return f'PE${self.point}$'


@bloq_example
def _ec_pe() -> ECPhaseEstimateR:
n, p = sympy.symbols('n p ')
n, p = sympy.symbols('n p')
Rx, Ry = sympy.symbols('R_x R_y')
ec_pe = ECPhaseEstimateR(n=n, point=ECPoint(Rx, Ry, mod=p))
return ec_pe
Expand All @@ -90,4 +127,4 @@ def _ec_pe_small() -> ECPhaseEstimateR:
return ec_pe_small


_EC_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=ECPhaseEstimateR, examples=[_ec_pe])
_EC_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=ECPhaseEstimateR, examples=[_ec_pe, _ec_pe_small])
Loading

0 comments on commit 580ae65

Please sign in to comment.