From 8ba6e84cf442d4dfd7ba7579f88169e9fd8e874b Mon Sep 17 00:00:00 2001 From: Charles Yuan Date: Thu, 18 Jul 2024 10:06:05 -0700 Subject: [PATCH] Make static type checking work better with symbolics --- qualtran/bloqs/arithmetic/comparison.py | 4 +--- qualtran/bloqs/arithmetic/permutation.py | 6 ++--- qualtran/bloqs/arithmetic/sorting.py | 13 ++++------- .../block_encoding/linear_combination.py | 23 ++++++++----------- qualtran/bloqs/block_encoding/product.py | 9 ++------ .../bloqs/block_encoding/tensor_product.py | 8 +++---- .../bloqs/data_loading/select_swap_qrom.py | 4 ++-- qualtran/bloqs/qft/qft_text_book.py | 4 ++-- qualtran/bloqs/rotations/phase_gradient.py | 4 ++-- .../prepare_uniform_superposition.py | 1 - .../state_preparation_via_rotation.py | 1 - qualtran/symbolics/types.py | 18 ++++++++++++++- 12 files changed, 47 insertions(+), 48 deletions(-) diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index 0408d0fac..0190dd348 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -962,11 +962,9 @@ def is_symbolic(self): @property def bits_k(self) -> Union[tuple[int, ...], HasLength]: - if self.is_symbolic(): + if is_symbolic(self.bitsize) or is_symbolic(self.val): return HasLength(self.bitsize) - assert not isinstance(self.bitsize, sympy.Expr) - assert not isinstance(self.val, sympy.Expr) return tuple(QUInt(self.bitsize).to_bits(self.val)) def build_composite_bloq( diff --git a/qualtran/bloqs/arithmetic/permutation.py b/qualtran/bloqs/arithmetic/permutation.py index 8738fab1c..9c7a145e9 100644 --- a/qualtran/bloqs/arithmetic/permutation.py +++ b/qualtran/bloqs/arithmetic/permutation.py @@ -108,9 +108,8 @@ def is_symbolic(self): return is_symbolic(self.N, self.cycle) def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'SoquetT']: - if self.is_symbolic(): + if is_symbolic(self.cycle): raise DecomposeTypeError(f"cannot decompose symbolic {self}") - assert not isinstance(self.cycle, Shaped) a: 'SoquetT' = bb.allocate(dtype=QBit()) @@ -253,10 +252,9 @@ def from_cycle_lengths(cls, N: SymbolicInt, cycle_lengths: tuple[SymbolicInt, .. return cls(N, cycles) def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'SoquetT']: - if self.is_symbolic(): + if is_symbolic(self.cycles): raise DecomposeTypeError(f"cannot decompose symbolic {self}") - assert not isinstance(self.cycles, Shaped) for cycle in self.cycles: x = bb.add(PermutationCycle(self.N, cycle), x=x) diff --git a/qualtran/bloqs/arithmetic/sorting.py b/qualtran/bloqs/arithmetic/sorting.py index 85e0f0c60..6e0fabf21 100644 --- a/qualtran/bloqs/arithmetic/sorting.py +++ b/qualtran/bloqs/arithmetic/sorting.py @@ -147,7 +147,7 @@ def is_symbolic(self): @cached_property def num_comparisons(self) -> SymbolicInt: """Number of `Comparator` gates used in the decomposition""" - if is_symbolic(self.k, self.offset): + if is_symbolic(self.k) or is_symbolic(self.offset): return self.k // 2 # upper bound full = (self.k // (2 * self.offset)) * self.offset @@ -155,12 +155,11 @@ def num_comparisons(self) -> SymbolicInt: return full + max(rest - self.offset, 0) def build_composite_bloq(self, bb: 'BloqBuilder', xs: 'SoquetT') -> Dict[str, 'SoquetT']: - if is_symbolic(self.k, self.offset): + if is_symbolic(self.k) or is_symbolic(self.offset): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") - # make mypy happy - k = int(self.k) - offset = int(self.offset) + k = self.k + offset = self.offset assert isinstance(xs, np.ndarray) comp = Comparator(self.bitsize) @@ -225,7 +224,6 @@ class BitonicMerge(Bloq): def __attrs_post_init__(self): k = self.half_length if not is_symbolic(k): - assert not isinstance(k, sympy.Expr) assert k >= 1, "length of input lists must be positive" # TODO(#1090) support non-power-of-two input lengths assert (k & (k - 1)) == 0, "length of input lists must be a power of 2" @@ -260,7 +258,7 @@ def build_composite_bloq( assert isinstance(xs, np.ndarray) assert isinstance(ys, np.ndarray) - k = int(self.half_length) + k = self.half_length first_round_junk = [] for i in range(k): @@ -340,7 +338,6 @@ class BitonicSort(Bloq): def __attrs_post_init__(self): k = self.k if not is_symbolic(k): - assert not isinstance(k, sympy.Expr) assert k >= 1, f"length of input list must be positive, got {k=}" # TODO(#1090) support non-power-of-two input lengths assert (k & (k - 1)) == 0, f"length of input list must be a power of 2, got {k=}" diff --git a/qualtran/bloqs/block_encoding/linear_combination.py b/qualtran/bloqs/block_encoding/linear_combination.py index 2ec187b6f..b8a8061d2 100644 --- a/qualtran/bloqs/block_encoding/linear_combination.py +++ b/qualtran/bloqs/block_encoding/linear_combination.py @@ -192,8 +192,7 @@ def prepare(self) -> BlackBoxPrepare: raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") alt, keep, mu = preprocess_probabilities_for_reversible_sampling( - unnormalized_probabilities=tuple(self.rescaled_lambd), - sub_bit_precision=cast(int, self.lambd_bits), + unnormalized_probabilities=tuple(self.rescaled_lambd), sub_bit_precision=self.lambd_bits ) N = len(self.rescaled_lambd) @@ -224,14 +223,14 @@ def select(self) -> BlackBoxSelect: or is_symbolic(self.resource_bitsize) ): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") - assert isinstance(self.be_ancilla_bitsize, int) - assert isinstance(self.be_resource_bitsize, int) + assert not is_symbolic(self.be_ancilla_bitsize) + assert not is_symbolic(self.be_resource_bitsize) # make all bloqs have same ancilla and resource registers bloqs = [] for be in self.signed_block_encodings: - assert isinstance(be.ancilla_bitsize, int) - assert isinstance(be.resource_bitsize, int) + assert not is_symbolic(be.ancilla_bitsize) + assert not is_symbolic(be.resource_bitsize) partitions: List[Tuple[Register, List[Union[str, Unused]]]] = [ (Register("system", QAny(self.system_bitsize)), ["system"]) @@ -267,17 +266,15 @@ def build_composite_bloq( or is_symbolic(self.resource_bitsize) ): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") - assert isinstance(self.be_ancilla_bitsize, int) - assert isinstance(self.ancilla_bitsize, int) - assert isinstance(self.be_resource_bitsize, int) - assert isinstance(self.resource_bitsize, int) + assert not is_symbolic(self.be_ancilla_bitsize) + assert not is_symbolic(self.be_resource_bitsize) # partition ancilla register be_system_soqs: Dict[str, SoquetT] = {"system": system} anc_regs = [Register("selection", QAny(self.prepare.selection_bitsize))] if self.be_ancilla_bitsize > 0: anc_regs.append(Register("ancilla", QAny(self.be_ancilla_bitsize))) - anc_part = Partition(cast(int, self.ancilla_bitsize), tuple(anc_regs)) + anc_part = Partition(self.ancilla_bitsize, tuple(anc_regs)) anc_soqs = bb.add_d(anc_part, x=ancilla) if self.be_ancilla_bitsize > 0: be_system_soqs["ancilla"] = anc_soqs.pop("ancilla") @@ -290,7 +287,7 @@ def build_composite_bloq( res_regs.append(Register("resource", QAny(self.be_resource_bitsize))) if self.prepare.junk_bitsize > 0: res_regs.append(Register("prepare_junk", QAny(self.prepare.junk_bitsize))) - res_part = Partition(cast(int, self.resource_bitsize), tuple(res_regs)) + res_part = Partition(self.resource_bitsize, tuple(res_regs)) res_soqs = bb.add_d(res_part, x=soqs.pop("resource")) if self.be_resource_bitsize > 0: be_system_soqs["resource"] = res_soqs.pop("resource") @@ -303,7 +300,7 @@ def build_composite_bloq( be_regs.append(Register("ancilla", QAny(self.be_ancilla_bitsize))) if self.be_resource_bitsize > 0: be_regs.append(Register("resource", QAny(self.be_resource_bitsize))) - be_part = Partition(cast(int, self.select.system_bitsize), tuple(be_regs)) + be_part = Partition(self.select.system_bitsize, tuple(be_regs)) prepare_soqs = bb.add_d(self.prepare, **prepare_in_soqs) select_out_soqs = bb.add_d( diff --git a/qualtran/bloqs/block_encoding/product.py b/qualtran/bloqs/block_encoding/product.py index b1ab4b079..111cc9857 100644 --- a/qualtran/bloqs/block_encoding/product.py +++ b/qualtran/bloqs/block_encoding/product.py @@ -146,11 +146,6 @@ def build_composite_bloq( or is_symbolic(self.resource_bitsize) ): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") - assert ( - isinstance(self.system_bitsize, int) - and isinstance(self.ancilla_bitsize, int) - and isinstance(self.resource_bitsize, int) - ) n = len(self.block_encodings) if self.ancilla_bitsize > 0: @@ -176,8 +171,8 @@ def build_composite_bloq( # connect constituent bloqs for i, u in enumerate(reversed(self.block_encodings)): - assert isinstance(u.ancilla_bitsize, int) - assert isinstance(u.resource_bitsize, int) + assert not is_symbolic(u.ancilla_bitsize) + assert not is_symbolic(u.resource_bitsize) u_soqs = {"system": system} partition: List[Tuple[Register, List[Union[str, Unused]]]] = [ (Register("system", dtype=QAny(u.system_bitsize)), ["system"]) diff --git a/qualtran/bloqs/block_encoding/tensor_product.py b/qualtran/bloqs/block_encoding/tensor_product.py index ca869bb13..c5cf7631e 100644 --- a/qualtran/bloqs/block_encoding/tensor_product.py +++ b/qualtran/bloqs/block_encoding/tensor_product.py @@ -14,7 +14,7 @@ from collections import Counter from functools import cached_property -from typing import cast, Dict, Set, Tuple +from typing import Dict, Set, Tuple from attrs import evolve, field, frozen, validators @@ -142,13 +142,13 @@ def build_composite_bloq( if "resource" in u.signature._lefts ) - sys_part = Partition(cast(int, self.system_bitsize), regs=sys_regs) + sys_part = Partition(self.system_bitsize, regs=sys_regs) sys_out_regs = list(bb.add_t(sys_part, x=system)) if len(anc_regs) > 0: - anc_part = Partition(cast(int, self.ancilla_bitsize), regs=anc_regs) + anc_part = Partition(self.ancilla_bitsize, regs=anc_regs) anc_out_regs = list(bb.add_t(anc_part, x=soqs["ancilla"])) if len(res_regs) > 0: - res_part = Partition(cast(int, self.resource_bitsize), regs=res_regs) + res_part = Partition(self.resource_bitsize, regs=res_regs) res_out_regs = list(bb.add_t(res_part, x=soqs["resource"])) sys_i = 0 anc_i = 0 diff --git a/qualtran/bloqs/data_loading/select_swap_qrom.py b/qualtran/bloqs/data_loading/select_swap_qrom.py index d2710f338..7da76b9ae 100644 --- a/qualtran/bloqs/data_loading/select_swap_qrom.py +++ b/qualtran/bloqs/data_loading/select_swap_qrom.py @@ -29,7 +29,7 @@ from qualtran.bloqs.data_loading.qrom_base import QROMBase from qualtran.bloqs.swap_network import SwapWithZero from qualtran.drawing import Circle, Text, TextBox, WireSymbol -from qualtran.symbolics import ceil, is_symbolic, log2, prod, SymbolicInt +from qualtran.symbolics import ceil, is_symbolic, log2, prod, SymbolicFloat, SymbolicInt if TYPE_CHECKING: from qualtran import Bloq @@ -46,7 +46,7 @@ def find_optimal_log_block_size( * iteration_length/2^k + target_bitsize*(2^k - 1) is minimized. The corresponding block size for SelectSwapQROM would be 2^k. """ - k = 0.5 * log2(iteration_length / target_bitsize) + k: SymbolicFloat = 0.5 * log2(iteration_length / target_bitsize) if is_symbolic(k): return ceil(k) diff --git a/qualtran/bloqs/qft/qft_text_book.py b/qualtran/bloqs/qft/qft_text_book.py index afe4228ce..96c021b1c 100644 --- a/qualtran/bloqs/qft/qft_text_book.py +++ b/qualtran/bloqs/qft/qft_text_book.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cached_property -from typing import cast, Iterator, Set +from typing import Iterator, Set import attrs import cirq @@ -94,7 +94,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> Set['BloqCountT']: ) } else: - for i in range(1, cast(int, self.bitsize)): + for i in range(1, self.bitsize): ret |= {(PhaseGradientUnitary(i, exponent=0.5, is_controlled=True), 1)} if self.with_reverse: ret |= {(TwoBitSwap(), self.bitsize // 2)} diff --git a/qualtran/bloqs/rotations/phase_gradient.py b/qualtran/bloqs/rotations/phase_gradient.py index 0107a51f9..4a740115a 100644 --- a/qualtran/bloqs/rotations/phase_gradient.py +++ b/qualtran/bloqs/rotations/phase_gradient.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cached_property -from typing import cast, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union +from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union import attrs import cirq @@ -138,7 +138,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> Set['BloqCountT']: } ret: Set['BloqCountT'] = set() - for i in range(cast(int, self.bitsize)): + for i in range(self.bitsize): ret.add((gate(exponent=self.exponent / 2**i, eps=self.eps / self.bitsize), 1)) return ret diff --git a/qualtran/bloqs/state_preparation/prepare_uniform_superposition.py b/qualtran/bloqs/state_preparation/prepare_uniform_superposition.py index 104b88aed..ca4bd70e8 100644 --- a/qualtran/bloqs/state_preparation/prepare_uniform_superposition.py +++ b/qualtran/bloqs/state_preparation/prepare_uniform_superposition.py @@ -96,7 +96,6 @@ def k_l_logL(self) -> Tuple[SymbolicInt, SymbolicInt, SymbolicInt]: k, n, logL = 0, self.n, bit_length(self.n - 1) if is_symbolic(n): return 0, self.n, bit_length(self.n - 1) - n = int(n) while n > 1 and n % 2 == 0: k += 1 logL -= 1 diff --git a/qualtran/bloqs/state_preparation/state_preparation_via_rotation.py b/qualtran/bloqs/state_preparation/state_preparation_via_rotation.py index b7834bb30..91c1063c6 100644 --- a/qualtran/bloqs/state_preparation/state_preparation_via_rotation.py +++ b/qualtran/bloqs/state_preparation/state_preparation_via_rotation.py @@ -145,7 +145,6 @@ class StatePreparationViaRotations(GateWithRegisters): def __attrs_post_init__(self): if is_symbolic(self.state_coefficients): return - assert isinstance(self.state_coefficients, tuple) # a valid quantum state has a number of coefficients that is a power of two assert slen(self.state_coefficients) == 2**self.state_bitsize # negative number of control bits is not allowed diff --git a/qualtran/symbolics/types.py b/qualtran/symbolics/types.py index 6b4c21017..3bc42e229 100644 --- a/qualtran/symbolics/types.py +++ b/qualtran/symbolics/types.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import overload, TypeVar, Union import sympy from attrs import field, frozen, validators from cirq._doc import document +from typing_extensions import TypeIs SymbolicFloat = Union[float, sympy.Expr] document(SymbolicFloat, """A floating point value or a sympy expression.""") @@ -63,7 +64,22 @@ def is_symbolic(self): return True +T = TypeVar('T') + + +@overload +def is_symbolic( + arg: Union[T, sympy.Expr, Shaped, HasLength], / +) -> TypeIs[Union[sympy.Expr, Shaped, HasLength]]: + ... + + +@overload def is_symbolic(*args) -> bool: + ... + + +def is_symbolic(*args) -> Union[TypeIs[Union[sympy.Expr, Shaped, HasLength]], bool]: """Returns whether the inputs contain any symbolic object. Returns: