Skip to content

Commit

Permalink
Classical testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Apr 26, 2024
1 parent cbc3727 commit 63cacbd
Show file tree
Hide file tree
Showing 10 changed files with 805 additions and 32 deletions.
15 changes: 5 additions & 10 deletions dev_tools/qualtran_dev_tools/bloq_report_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from qualtran import Bloq, BloqExample
from qualtran.testing import (
BloqCheckResult,
check_bloq_example_classical_action,
check_bloq_example_decompose,
check_bloq_example_make,
check_bloq_example_serialize,
Expand Down Expand Up @@ -67,19 +68,12 @@ def bloq_classes_with_no_examples(


IDCOLS = ['package', 'bloq_cls', 'name']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize', 'typing']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize', 'typing', 'classical']


def record_for_class_with_no_examples(k: Type[Bloq]) -> Dict[str, Any]:
return {
'bloq_cls': k.__name__,
'package': _get_package(k),
'name': '-',
'make': BloqCheckResult.MISSING,
'decomp': BloqCheckResult.MISSING,
'counts': BloqCheckResult.MISSING,
'serialize': BloqCheckResult.MISSING,
'typing': BloqCheckResult.MISSING,
return {'bloq_cls': k.__name__, 'package': _get_package(k), 'name': '-'} | {
checkcol: BloqCheckResult.MISSING for checkcol in CHECKCOLS
}


Expand All @@ -93,6 +87,7 @@ def record_for_bloq_example(be: BloqExample) -> Dict[str, Any]:
'counts': check_equivalent_bloq_example_counts(be)[0],
'serialize': check_bloq_example_serialize(be)[0],
'typing': check_connections_preserve_preserves_types(be)[0],
'classical': check_bloq_example_classical_action(be)[0],
}


Expand Down
50 changes: 49 additions & 1 deletion qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

import abc
from enum import Enum
from typing import Any, cast, Iterable, List, Sequence, Union
from typing import Any, cast, Iterable, List, Optional, Sequence, Tuple, Union

import attrs
import numpy as np
Expand All @@ -72,6 +72,13 @@ def get_classical_domain(self) -> Iterable[Any]:
"""Yields all possible classical (computational basis state) values representable
by this type."""

@abc.abstractmethod
def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
"""Returns a random classical (computational basis state) value representable
by this type."""

@abc.abstractmethod
def to_bits(self, x) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
Expand Down Expand Up @@ -119,6 +126,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[int]:
yield from (0, 1)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.choice([0, 1], size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not (val == 0 or val == 1):
raise ValueError(f"Bad {self} value {val} in {debug_str}")
Expand Down Expand Up @@ -164,6 +176,11 @@ def from_bits(self, bits: Sequence[int]) -> int:
# TODO: Raise an error once usage of `QAny` is minimized across the library
return QUInt(self.bitsize).from_bits(bits)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.")

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass

Expand Down Expand Up @@ -203,6 +220,11 @@ def from_bits(self, bits: Sequence[int]) -> int:
x = QUInt(self.bitsize - 1).from_bits([1 - x if sign else x for x in bits[1:]])
return ~x if sign else x

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(-(2 ** (self.bitsize - 1)), 2 ** (self.bitsize - 1), size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -255,6 +277,11 @@ def get_classical_domain(self) -> Iterable[int]:
max_val = 1 << (self.bitsize - 1)
return range(-max_val + 1, max_val)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -285,6 +312,12 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[Any]:
return range(2**self.bitsize)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)
return range(2**self.bitsize)

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
Expand Down Expand Up @@ -382,6 +415,11 @@ def get_classical_domain(self) -> Iterable[Any]:
return range(0, self.iteration_length)
raise ValueError(f'Classical Domain not defined for expression: {self.iteration_length}')

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(0, self.iteration_length, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -486,6 +524,11 @@ def _assert_valid_classical_val(self, val: Union[float, Fxp], debug_str: str = '
f"{debug_str}={val} cannot be accurately represented using Fxp {fxp_val}"
)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()

def assert_valid_classical_val(self, val: Union[float, Fxp], debug_str: str = 'val'):
# TODO: Asserting a valid value here opens a can of worms because classical data, except integers,
# is currently not propagated correctly through Bloqs
Expand Down Expand Up @@ -541,6 +584,11 @@ def to_bits(self, x: int) -> List[int]:
def from_bits(self, bits: Sequence[int]) -> int:
raise NotImplementedError(f"from_bits not implemented for {self}")

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down
30 changes: 30 additions & 0 deletions qualtran/bloqs/factoring/_big_int_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import SupportsInt, Union

import attrs
import sympy


def int_or_expr(x: Union[SupportsInt, sympy.Expr]) -> Union[int, sympy.Expr]:
if isinstance(x, sympy.Expr):
return x
return int(x)


def python_int_field():
"""An attrs field that requires an infinite-precision Python integer."""
return attrs.field(
validator=attrs.validators.instance_of((int, sympy.Expr)), converter=int_or_expr
)
14 changes: 8 additions & 6 deletions qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from functools import cached_property
from typing import Dict, Optional, Set, Union

import attrs
import attrs.validators
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -32,6 +32,7 @@
SoquetT,
)
from qualtran.bloqs.basic_gates import IntState
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
Expand Down Expand Up @@ -62,10 +63,10 @@ class ModExp(Bloq):
Gidney and Ekerå. 2019.
"""

base: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
exp_bitsize: Union[int, sympy.Expr]
x_bitsize: Union[int, sympy.Expr]
base: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
exp_bitsize: Union[int, sympy.Expr] = field()
x_bitsize: Union[int, sympy.Expr] = field()

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -119,6 +120,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
}

def on_classical_vals(self, exponent: int):
exponent = int_or_expr(exponent)
return {'exponent': exponent, 'x': (self.base**exponent) % self.mod}

def short_name(self) -> str:
Expand Down
12 changes: 8 additions & 4 deletions qualtran/bloqs/factoring/mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import Bloq
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.bloqs.util_bloqs import Join, Split
from qualtran.resource_counting import SympySymbolAllocator
from qualtran.testing import execute_notebook


def test_mod_exp_consistent_classical():
Expand Down Expand Up @@ -85,6 +85,10 @@ def generalize(b: Bloq) -> Optional[Bloq]:
assert counts1 == counts2


def test_modexp_small(bloq_autotester):
bloq_autotester(_modexp_small)


def test_modexp(bloq_autotester):
bloq_autotester(_modexp)

Expand All @@ -95,9 +99,9 @@ def test_modexp_symb(bloq_autotester):

@pytest.mark.notebook
def test_intro_notebook():
execute_notebook('factoring-via-modexp')
qlt_testing.execute_notebook('factoring-via-modexp')


@pytest.mark.notebook
def test_notebook():
execute_notebook('mod_exp')
qlt_testing.execute_notebook('mod_exp')
12 changes: 7 additions & 5 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import attrs
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -33,6 +33,7 @@
)
from qualtran.bloqs.arithmetic.addition import SimpleAddConstant
from qualtran.bloqs.basic_gates import CNOT, CSwap, XGate
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.factoring.mod_add import CtrlScaleModAdd
from qualtran.drawing import Circle, directional_text_box, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand All @@ -54,9 +55,9 @@ class CtrlModMul(Bloq):
x: The integer being multiplied
"""

k: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
bitsize: Union[int, sympy.Expr]
k: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
bitsize: Union[int, sympy.Expr] = field()

def __attrs_post_init__(self):
if isinstance(self.k, sympy.Expr):
Expand Down Expand Up @@ -101,11 +102,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
k = ssa.new_symbol('k')
return {(self._Add(k=k), 2), (CSwap(self.bitsize), 1)}

def on_classical_vals(self, ctrl, x) -> Dict[str, ClassicalValT]:
def on_classical_vals(self, ctrl: int, x: int) -> Dict[str, ClassicalValT]:
if ctrl == 0:
return {'ctrl': ctrl, 'x': x}

assert ctrl == 1, ctrl
x = int_or_expr(x)
return {'ctrl': ctrl, 'x': (x * self.k) % self.mod}

def short_name(self) -> str:
Expand Down
20 changes: 17 additions & 3 deletions qualtran/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import qualtran.testing as qlt_testing
Expand Down Expand Up @@ -75,6 +75,21 @@ def assert_equivalent_bloq_example_counts_for_pytest(bloq_ex: BloqExample):
raise bce from bce


def assert_bloq_example_classical_action_for_pytest(bloq_ex: BloqExample):
rng = np.random.default_rng(seed=52)
try:
qlt_testing.assert_bloq_example_classical_action(bloq_ex, rng=rng)
except qlt_testing.BloqCheckException as bce:
if bce.check_result in [
qlt_testing.BloqCheckResult.UNVERIFIED,
qlt_testing.BloqCheckResult.NA,
qlt_testing.BloqCheckResult.MISSING,
]:
pytest.skip(bce.msg)

raise bce from bce


def assert_bloq_example_serialize_for_pytest(bloq_ex: BloqExample):
if bloq_ex.name in [
'prep_sparse',
Expand Down Expand Up @@ -112,8 +127,7 @@ def assert_bloq_example_typing_for_pytest(bloq_ex: BloqExample):
('make', assert_bloq_example_make_for_pytest),
('decompose', assert_bloq_example_decompose_for_pytest),
('counts', assert_equivalent_bloq_example_counts_for_pytest),
('serialization', assert_bloq_example_serialize_for_pytest),
('typing', assert_bloq_example_typing_for_pytest),
('classical', assert_bloq_example_classical_action_for_pytest),
]


Expand Down
Loading

0 comments on commit 63cacbd

Please sign in to comment.