Skip to content

Commit

Permalink
Classical testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Aug 6, 2024
1 parent 6a4b920 commit dd910bf
Show file tree
Hide file tree
Showing 10 changed files with 816 additions and 20 deletions.
4 changes: 3 additions & 1 deletion dev_tools/qualtran_dev_tools/bloq_report_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from qualtran import Bloq, BloqExample
from qualtran.testing import (
BloqCheckResult,
check_bloq_example_classical_action,
check_bloq_example_decompose,
check_bloq_example_make,
check_bloq_example_qtyping,
Expand Down Expand Up @@ -69,7 +70,7 @@ def bloq_classes_with_no_examples(


IDCOLS = ['package', 'bloq_cls', 'name']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize', 'qtyping']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize', 'qtyping', 'classical']


def record_for_class_with_no_examples(k: Type[Bloq]) -> Dict[str, Any]:
Expand All @@ -89,6 +90,7 @@ def record_for_bloq_example(be: BloqExample) -> Dict[str, Any]:
'counts': check_equivalent_bloq_example_counts(be)[0],
'serialize': check_bloq_example_serializes(be)[0],
'qtyping': check_bloq_example_qtyping(be)[0],
'classical': check_bloq_example_classical_action(be)[0],
}
dur = time.perf_counter() - start
if dur > 1.0:
Expand Down
54 changes: 52 additions & 2 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

import abc
from enum import Enum
from typing import Any, Iterable, List, Sequence, Union
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

import attrs
import numpy as np
Expand All @@ -73,6 +73,13 @@ def get_classical_domain(self) -> Iterable[Any]:
"""Yields all possible classical (computational basis state) values representable
by this type."""

@abc.abstractmethod
def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
"""Returns a random classical (computational basis state) value representable
by this type."""

@abc.abstractmethod
def to_bits(self, x) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
Expand Down Expand Up @@ -151,6 +158,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[int]:
yield from (0, 1)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.choice([0, 1], size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not (val == 0 or val == 1):
raise ValueError(f"Bad {self} value {val} in {debug_str}")
Expand Down Expand Up @@ -202,6 +214,11 @@ def from_bits(self, bits: Sequence[int]) -> int:
def is_symbolic(self) -> bool:
return is_symbolic(self.bitsize)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.")

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass

Expand Down Expand Up @@ -250,6 +267,11 @@ def from_bits(self, bits: Sequence[int]) -> int:
)
return ~x if sign else x

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(-(2 ** (self.bitsize - 1)), 2 ** (self.bitsize - 1), size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -308,6 +330,11 @@ def get_classical_domain(self) -> Iterable[int]:
max_val = 1 << (self.bitsize - 1)
return range(-max_val + 1, max_val)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -341,6 +368,11 @@ def is_symbolic(self) -> bool:
def get_classical_domain(self) -> Iterable[Any]:
return range(2**self.bitsize)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
Expand Down Expand Up @@ -486,13 +518,21 @@ def get_classical_domain(self) -> Iterable[Any]:
return range(0, self.iteration_length)
raise ValueError(f'Classical Domain not defined for expression: {self.iteration_length}')

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(0, self.iteration_length, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
if val < 0:
raise ValueError(f"Negative classical value encountered in {debug_str}")
if val >= self.iteration_length:
raise ValueError(f"Too-large classical value encountered in {debug_str}")
raise ValueError(
f"Too-large classical value encountered in {debug_str}: "
f"{val} >= {self.iteration_length}"
)

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
Expand Down Expand Up @@ -745,6 +785,11 @@ def _assert_valid_classical_val(self, val: Union[float, Fxp], debug_str: str = '
f"{debug_str}={val} cannot be accurately represented using Fxp {fxp_val}"
)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()


@attrs.frozen
class QMontgomeryUInt(QDType):
Expand Down Expand Up @@ -792,6 +837,11 @@ def to_bits(self, x: int) -> List[int]:
def from_bits(self, bits: Sequence[int]) -> int:
raise NotImplementedError(f"from_bits not implemented for {self}")

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down
30 changes: 30 additions & 0 deletions qualtran/bloqs/factoring/_big_int_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import SupportsInt, Union

import attrs
import sympy


def int_or_expr(x: Union[SupportsInt, sympy.Expr]) -> Union[int, sympy.Expr]:
if isinstance(x, sympy.Expr):
return x
return int(x)


def python_int_field():
"""An attrs field that requires an infinite-precision Python integer."""
return attrs.field(
validator=attrs.validators.instance_of((int, sympy.Expr)), converter=int_or_expr
)
14 changes: 8 additions & 6 deletions qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from functools import cached_property
from typing import Dict, Optional, Set, Tuple, Union

import attrs
import attrs.validators
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -33,6 +33,7 @@
SoquetT,
)
from qualtran.bloqs.basic_gates import IntState
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.drawing import Text, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -64,10 +65,10 @@ class ModExp(Bloq):
Gidney and Ekerå. 2019.
"""

base: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
exp_bitsize: Union[int, sympy.Expr]
x_bitsize: Union[int, sympy.Expr]
base: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
exp_bitsize: Union[int, sympy.Expr] = field()
x_bitsize: Union[int, sympy.Expr] = field()

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -121,6 +122,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
}

def on_classical_vals(self, exponent: int):
exponent = int_or_expr(exponent)
return {'exponent': exponent, 'x': (self.base**exponent) % self.mod}

def wire_symbol(
Expand Down
12 changes: 8 additions & 4 deletions qualtran/bloqs/factoring/mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import Bloq
from qualtran.bloqs.bookkeeping import Join, Split
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.drawing import Text
from qualtran.resource_counting import SympySymbolAllocator
from qualtran.testing import execute_notebook


def test_mod_exp_consistent_classical():
Expand Down Expand Up @@ -92,6 +92,10 @@ def test_mod_exp_t_complexity():
assert tcomp.t > 0


def test_modexp_small(bloq_autotester):
bloq_autotester(_modexp_small)


def test_modexp(bloq_autotester):
bloq_autotester(_modexp)

Expand All @@ -102,9 +106,9 @@ def test_modexp_symb(bloq_autotester):

@pytest.mark.notebook
def test_intro_notebook():
execute_notebook('factoring-via-modexp')
qlt_testing.execute_notebook('factoring-via-modexp')


@pytest.mark.notebook
def test_notebook():
execute_notebook('mod_exp')
qlt_testing.execute_notebook('mod_exp')
12 changes: 7 additions & 5 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import attrs
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -33,6 +33,7 @@
)
from qualtran.bloqs.arithmetic.addition import AddK
from qualtran.bloqs.basic_gates import CNOT, CSwap, XGate
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.mod_arithmetic import CtrlScaleModAdd
from qualtran.drawing import Circle, directional_text_box, Text, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand All @@ -54,9 +55,9 @@ class CtrlModMul(Bloq):
x: The integer being multiplied
"""

k: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
bitsize: Union[int, sympy.Expr]
k: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
bitsize: Union[int, sympy.Expr] = field()

def __attrs_post_init__(self):
if isinstance(self.k, sympy.Expr):
Expand Down Expand Up @@ -104,11 +105,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
k = ssa.new_symbol('k')
return {(self._Add(k=k), 2), (CSwap(self.bitsize), 1)}

def on_classical_vals(self, ctrl, x) -> Dict[str, ClassicalValT]:
def on_classical_vals(self, ctrl: int, x: int) -> Dict[str, ClassicalValT]:
if ctrl == 0:
return {'ctrl': ctrl, 'x': x}

assert ctrl == 1, ctrl
x = int_or_expr(x)
return {'ctrl': ctrl, 'x': (x * self.k) % self.mod}

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
Expand Down
18 changes: 17 additions & 1 deletion qualtran/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import qualtran.testing as qlt_testing
Expand Down Expand Up @@ -75,6 +75,21 @@ def assert_equivalent_bloq_example_counts_for_pytest(bloq_ex: BloqExample):
raise bce from bce


def assert_bloq_example_classical_action_for_pytest(bloq_ex: BloqExample):
rng = np.random.default_rng(seed=52)
try:
qlt_testing.assert_bloq_example_classical_action(bloq_ex, rng=rng)
except qlt_testing.BloqCheckException as bce:
if bce.check_result in [
qlt_testing.BloqCheckResult.UNVERIFIED,
qlt_testing.BloqCheckResult.NA,
qlt_testing.BloqCheckResult.MISSING,
]:
pytest.skip(bce.msg)

raise bce from bce


def assert_bloq_example_serializes_for_pytest(bloq_ex: BloqExample):
if bloq_ex.name in [
'prep_sparse',
Expand Down Expand Up @@ -149,6 +164,7 @@ def assert_bloq_example_qtyping_for_pytest(bloq_ex: BloqExample):
('counts', assert_equivalent_bloq_example_counts_for_pytest),
('serialize', assert_bloq_example_serializes_for_pytest),
('qtyping', assert_bloq_example_qtyping_for_pytest),
('classical', assert_bloq_example_classical_action_for_pytest),
]


Expand Down
Loading

0 comments on commit dd910bf

Please sign in to comment.