Skip to content

Commit

Permalink
Classical check
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Mar 13, 2024
1 parent 9558f78 commit e5b3a3f
Show file tree
Hide file tree
Showing 10 changed files with 807 additions and 29 deletions.
13 changes: 5 additions & 8 deletions dev_tools/qualtran_dev_tools/bloq_report_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from qualtran import Bloq, BloqExample
from qualtran.testing import (
BloqCheckResult,
check_bloq_example_classical_action,
check_bloq_example_decompose,
check_bloq_example_make,
check_equivalent_bloq_example_counts,
Expand Down Expand Up @@ -65,17 +66,12 @@ def bloq_classes_with_no_examples(


IDCOLS = ['package', 'bloq_cls', 'name']
CHECKCOLS = ['make', 'decomp', 'counts']
CHECKCOLS = ['make', 'decomp', 'counts', 'classical']


def record_for_class_with_no_examples(k: Type[Bloq]) -> Dict[str, Any]:
return {
'bloq_cls': k.__name__,
'package': _get_package(k),
'name': '-',
'make': BloqCheckResult.MISSING,
'decomp': BloqCheckResult.MISSING,
'counts': BloqCheckResult.MISSING,
return {'bloq_cls': k.__name__, 'package': _get_package(k), 'name': '-'} | {
checkcol: BloqCheckResult.MISSING for checkcol in CHECKCOLS
}


Expand All @@ -87,6 +83,7 @@ def record_for_bloq_example(be: BloqExample) -> Dict[str, Any]:
'make': check_bloq_example_make(be)[0],
'decomp': check_bloq_example_decompose(be)[0],
'counts': check_equivalent_bloq_example_counts(be)[0],
'classical': check_bloq_example_classical_action(be)[0],
}


Expand Down
53 changes: 50 additions & 3 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"""

import abc
from typing import Any, Iterable, Union
from typing import Any, Iterable, Optional, Tuple, Union

import attrs
import numpy as np
Expand All @@ -70,6 +70,13 @@ def get_classical_domain(self) -> Iterable[Any]:
"""Yields all possible classical (computational basis state) values representable
by this type."""

@abc.abstractmethod
def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
"""Returns a random classical (computational basis state) value representable
by this type."""

@abc.abstractmethod
def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
"""Raises an exception if `val` is not a valid classical value for this type.
Expand Down Expand Up @@ -106,6 +113,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[int]:
yield from (0, 1)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.choice([0, 1], size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not (val == 0 or val == 1):
raise ValueError(f"Bad {self} value {val} in {debug_str}")
Expand All @@ -128,6 +140,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[Any]:
raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.")

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.")

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass

Expand All @@ -154,6 +171,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[int]:
return range(-(2 ** (self.bitsize - 1)), 2 ** (self.bitsize - 1))

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(-(2 ** (self.bitsize - 1)), 2 ** (self.bitsize - 1), size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -193,6 +215,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[Any]:
raise NotImplementedError()

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass # TODO: implement

Expand All @@ -215,7 +242,12 @@ def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[Any]:
return range(2 ** (self.bitsize))
return range(2**self.bitsize)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down Expand Up @@ -301,6 +333,11 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[Any]:
return range(0, self.iteration_length)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(0, self.iteration_length, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -366,6 +403,11 @@ def __attrs_post_init__(self):
def get_classical_domain(self) -> Iterable[Any]:
raise NotImplementedError()

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
raise NotImplementedError()

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass # TODO: implement

Expand Down Expand Up @@ -405,7 +447,12 @@ def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[Any]:
return range(2 ** (self.bitsize))
return range(2**self.bitsize)

def get_random_classical_vals(
self, rng: np.random.Generator, size: Optional[Union[int, Tuple[int, ...]]] = None
) -> Any:
return rng.integers(2**self.bitsize, size=size)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down
30 changes: 30 additions & 0 deletions qualtran/bloqs/factoring/_big_int_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import SupportsInt, Union

import attrs
import sympy


def int_or_expr(x: Union[SupportsInt, sympy.Expr]) -> Union[int, sympy.Expr]:
if isinstance(x, sympy.Expr):
return x
return int(x)


def python_int_field():
"""An attrs field that requires an infinite-precision Python integer."""
return attrs.field(
validator=attrs.validators.instance_of((int, sympy.Expr)), converter=int_or_expr
)
14 changes: 8 additions & 6 deletions qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from functools import cached_property
from typing import Dict, Optional, Set, Union

import attrs
import attrs.validators
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -32,6 +32,7 @@
SoquetT,
)
from qualtran.bloqs.basic_gates import IntState
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
Expand Down Expand Up @@ -62,10 +63,10 @@ class ModExp(Bloq):
[arxiv:1905.09749](https://arxiv.org/abs/1905.09749). Gidney and Ekerå. 2019.
"""

base: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
exp_bitsize: Union[int, sympy.Expr]
x_bitsize: Union[int, sympy.Expr]
base: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
exp_bitsize: Union[int, sympy.Expr] = field()
x_bitsize: Union[int, sympy.Expr] = field()

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -119,6 +120,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
}

def on_classical_vals(self, exponent: int):
exponent = int_or_expr(exponent)
return {'exponent': exponent, 'x': (self.base**exponent) % self.mod}

def short_name(self) -> str:
Expand Down
12 changes: 8 additions & 4 deletions qualtran/bloqs/factoring/mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import Bloq
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.bloqs.util_bloqs import Join, Split
from qualtran.resource_counting import SympySymbolAllocator
from qualtran.testing import execute_notebook


def test_mod_exp_consistent_classical():
Expand Down Expand Up @@ -85,6 +85,10 @@ def generalize(b: Bloq) -> Optional[Bloq]:
assert counts1 == counts2


def test_modexp_small(bloq_autotester):
bloq_autotester(_modexp_small)


def test_modexp(bloq_autotester):
bloq_autotester(_modexp)

Expand All @@ -95,9 +99,9 @@ def test_modexp_symb(bloq_autotester):

@pytest.mark.notebook
def test_intro_notebook():
execute_notebook('factoring-via-modexp')
qlt_testing.execute_notebook('factoring-via-modexp')


@pytest.mark.notebook
def test_notebook():
execute_notebook('mod_exp')
qlt_testing.execute_notebook('mod_exp')
12 changes: 7 additions & 5 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import attrs
import numpy as np
import sympy
from attrs import frozen
from attrs import field, frozen

from qualtran import (
Bloq,
Expand All @@ -33,6 +33,7 @@
)
from qualtran.bloqs.arithmetic.addition import SimpleAddConstant
from qualtran.bloqs.basic_gates import CNOT, CSwap, XGate
from qualtran.bloqs.factoring._big_int_helpers import int_or_expr, python_int_field
from qualtran.bloqs.factoring.mod_add import CtrlScaleModAdd
from qualtran.drawing import Circle, directional_text_box, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand All @@ -54,9 +55,9 @@ class CtrlModMul(Bloq):
x: The integer being multiplied
"""

k: Union[int, sympy.Expr]
mod: Union[int, sympy.Expr]
bitsize: Union[int, sympy.Expr]
k: Union[int, sympy.Expr] = python_int_field()
mod: Union[int, sympy.Expr] = python_int_field()
bitsize: Union[int, sympy.Expr] = field()

def __attrs_post_init__(self):
if isinstance(self.k, sympy.Expr):
Expand Down Expand Up @@ -101,11 +102,12 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
k = ssa.new_symbol('k')
return {(self._Add(k=k), 2), (CSwap(self.bitsize), 1)}

def on_classical_vals(self, ctrl, x) -> Dict[str, ClassicalValT]:
def on_classical_vals(self, ctrl: int, x: int) -> Dict[str, ClassicalValT]:
if ctrl == 0:
return {'ctrl': ctrl, 'x': x}

assert ctrl == 1, ctrl
x = int_or_expr(x)
return {'ctrl': ctrl, 'x': (x * self.k) % self.mod}

def short_name(self) -> str:
Expand Down
18 changes: 17 additions & 1 deletion qualtran/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import qualtran.testing as qlt_testing
Expand Down Expand Up @@ -75,10 +75,26 @@ def assert_equivalent_bloq_example_counts_for_pytest(bloq_ex: BloqExample):
raise bce from bce


def assert_bloq_example_classical_action_for_pytest(bloq_ex: BloqExample):
rng = np.random.default_rng(seed=52)
try:
qlt_testing.assert_bloq_example_classical_action(bloq_ex, rng=rng)
except qlt_testing.BloqCheckException as bce:
if bce.check_result in [
qlt_testing.BloqCheckResult.UNVERIFIED,
qlt_testing.BloqCheckResult.NA,
qlt_testing.BloqCheckResult.MISSING,
]:
pytest.skip(bce.msg)

raise bce from bce


_TESTFUNCS = [
('make', assert_bloq_example_make_for_pytest),
('decompose', assert_bloq_example_decompose_for_pytest),
('counts', assert_equivalent_bloq_example_counts_for_pytest),
('classical', assert_bloq_example_classical_action_for_pytest),
]


Expand Down
Loading

0 comments on commit e5b3a3f

Please sign in to comment.