Skip to content

Commit

Permalink
Mypy issues part 3 (#898)
Browse files Browse the repository at this point in the history
* Mypy issues part 3

More miscellaenous mypy fixes.
  • Loading branch information
dstrain115 authored Apr 29, 2024
1 parent ec6bb8b commit 44a86a8
Show file tree
Hide file tree
Showing 32 changed files with 179 additions and 120 deletions.
8 changes: 5 additions & 3 deletions dev_tools/qualtran_dev_tools/jupyter_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class NotebookSpecV2:

@directory.default
def _default_directory(self) -> str:
return str(Path(self.module.__file__).parent)
path = self.module.__file__
assert path is not None
return str(Path(path).parent)

@property
def path_stem(self):
Expand Down Expand Up @@ -156,7 +158,7 @@ class _PyCell(_Cell):
cell_id: str


def get_bloq_doc_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_MarkdownCell]:
def get_bloq_doc_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_Cell]:
"""Cells introducing the `bloq_cls`"""

md_doc: str = '\n'.join(get_markdown_docstring_lines(bloqdoc.bloq_cls))
Expand Down Expand Up @@ -230,7 +232,7 @@ def get_call_graph_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_Cell]:


def get_cells(bloqdoc: BloqDocSpec) -> List[_Cell]:
cells = []
cells: List[_Cell] = []
cid_prefix = f'{bloqdoc.bloq_cls.__name__}'
cells += get_bloq_doc_cells(bloqdoc, cid_prefix)
cells += get_example_instances_cells(bloqdoc, cid_prefix)
Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/bloq_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bloq_example(_func: Callable[[], _BloqType], **kwargs: Any) -> BloqExample[_

@typing.overload
def bloq_example(
_func: None, *, generalizer: _GeneralizerType = lambda x: x
_func: None = None, *, generalizer: _GeneralizerType = lambda x: x
) -> Callable[[Callable[[], _BloqType]], BloqExample[_BloqType]]:
...

Expand Down
9 changes: 5 additions & 4 deletions qualtran/bloqs/chemistry/df/double_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import CSwap, Hadamard, Toffoli
Expand Down Expand Up @@ -156,17 +157,17 @@ def short_name(self) -> str:
def build_composite_bloq(
self,
bb: 'BloqBuilder',
succ_l: SoquetT,
succ_l: Soquet,
l_ne_zero: SoquetT,
succ_p: SoquetT,
succ_p: Soquet,
p: SoquetT,
rot_aa: SoquetT,
spin: SoquetT,
xi: SoquetT,
offset: SoquetT,
rot: SoquetT,
rotations: SoquetT,
sys: SoquetT,
sys: NDArray[Soquet], # type: ignore[type-var]
) -> Dict[str, 'SoquetT']:
# 1st half
in_prep = InnerPrepareDoubleFactorization(
Expand Down Expand Up @@ -374,7 +375,7 @@ def signature(self) -> Signature:
def build_composite_bloq(
self,
bb: 'BloqBuilder',
ctrl: SoquetT,
ctrl: NDArray[Soquet], # type: ignore[type-var]
l: SoquetT,
p: SoquetT,
spin: SoquetT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
from attrs import evolve, field, frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand Down Expand Up @@ -107,7 +108,7 @@ def short_name(self) -> str:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if self.is_adjoint:
# inverting inequality tests at zero Toffoli.
return {}
return set()
else:
return {(Toffoli(), 6 * self.num_bits_t + 2)}

Expand Down Expand Up @@ -147,6 +148,7 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
(c_idx,) = soq.idx
filled = bool(self.cvs[c_idx])
return Circle(filled)
raise ValueError(f'Unknown name: {soq.reg.name}')

def build_composite_bloq(
self, bb: BloqBuilder, ctrl: SoquetT, sel: SoquetT, targets: SoquetT, junk: SoquetT
Expand Down Expand Up @@ -294,9 +296,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
succ_nu: SoquetT,
l: SoquetT,
Expand Down Expand Up @@ -462,7 +464,7 @@ def short_name(self) -> str:
def build_composite_bloq(
self,
bb: BloqBuilder,
ham_ctrl: SoquetT,
ham_ctrl: NDArray[Soquet], # type: ignore[type-var]
i_ne_j: SoquetT,
plus_t: SoquetT,
i: SoquetT,
Expand All @@ -478,7 +480,7 @@ def build_composite_bloq(
m: SoquetT,
l: SoquetT,
sys: SoquetT,
proj: SoquetT,
proj: NDArray[Soquet], # type: ignore[type-var]
) -> Dict[str, 'SoquetT']:
# ancilla for swaps from electronic and projectile system registers.
# we assume these are left in a clean state after SELECT operations
Expand Down Expand Up @@ -543,8 +545,10 @@ def build_composite_bloq(
j, sys, q = bb.add(
MultiplexedCSwap3D(self.num_bits_p, self.eta), sel=j, targets=sys, junk=q
)
_ = [bb.free(pi) for pi in p]
_ = [bb.free(qi) for qi in q]
for pi in p:
bb.free(pi)
for qi in q:
bb.free(qi)
ham_ctrl[:] = flag_t, flag_t_mean, flag_uv, flag_proj
bb.free(rl)
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
from attrs import frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand All @@ -28,6 +29,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import Toffoli
Expand Down Expand Up @@ -131,7 +133,7 @@ def signature(self) -> Signature:
@staticmethod
def _reshape_reg(
bb: BloqBuilder, in_reg: SoquetT, out_shape: Tuple[int, ...], bitsize: int
) -> SoquetT:
) -> NDArray[Soquet]: # type: ignore[type-var]
"""Reshape registers allocated as a big register.
Example:
Expand Down Expand Up @@ -162,6 +164,7 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
return TextBox('×(x)')
elif soq.reg.name == 'junk':
return TextBox('×(y)')
raise ValueError(f'Unknown name: {soq.reg.name}')

def short_name(self) -> str:
return 'MultiSwap'
Expand Down Expand Up @@ -279,9 +282,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
succ_nu: SoquetT,
l: SoquetT,
Expand Down Expand Up @@ -445,9 +448,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
l: SoquetT,
sys: SoquetT,
Expand Down Expand Up @@ -490,8 +493,10 @@ def build_composite_bloq(
j, sys, q = bb.add(
MultiplexedCSwap3D(self.num_bits_p, self.eta), sel=j, targets=sys, junk=q
)
_ = [bb.free(pi) for pi in p]
_ = [bb.free(qi) for qi in q]
for pi in p:
bb.free(pi)
for qi in q:
bb.free(qi)
bb.free(rl)
return {
'tuv': tuv,
Expand Down
8 changes: 5 additions & 3 deletions qualtran/bloqs/chemistry/sf/single_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import numpy as np
from attrs import evolve, frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand All @@ -37,6 +38,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran._infra.data_types import BoundedQUInt
Expand Down Expand Up @@ -329,10 +331,10 @@ def build_composite_bloq(
self,
bb: 'BloqBuilder',
*,
ctrl: SoquetT,
ctrl: NDArray[Soquet], # type: ignore[type-var]
l: SoquetT,
pq: SoquetT,
rot_aa: SoquetT,
pq: NDArray[Soquet], # type: ignore[type-var]
rot_aa: NDArray[Soquet], # type: ignore[type-var]
swap_pq: SoquetT,
spin: SoquetT,
sys: SoquetT,
Expand Down
14 changes: 8 additions & 6 deletions qualtran/bloqs/chemistry/sparse/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
QAny,
QBit,
Register,
Soquet,
SoquetT,
)
from qualtran.bloqs.arithmetic.comparison import LessThanEqual
Expand All @@ -43,6 +44,7 @@
from qualtran.linalg.lcu_util import preprocess_lcu_coefficients_for_reversible_sampling

if TYPE_CHECKING:
from qualtran import Bloq
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -102,12 +104,12 @@ def _add(p_indx: int, q_indx: int, r_indx: int, s_indx: int):
_add(q, q, q, p)
for p in range(num_spat):
_add(p, p, p, p)
eris_eight = np.array(eris_eight)
eris_eight_np = np.array(eris_eight)
pqrs_indx_np = np.array(pqrs_indx)
keep_indx = np.where(np.abs(eris_eight) > drop_element_thresh)
eris_eight = eris_eight[keep_indx]
keep_indx = np.where(np.abs(eris_eight_np) > drop_element_thresh)
eris_eight_np = eris_eight_np[keep_indx]
pqrs_indx_np = pqrs_indx_np[keep_indx[0]]
return np.concatenate((tpq_indx, pqrs_indx_np)), np.concatenate((tpq_sparse, eris_eight))
return np.concatenate((tpq_indx, pqrs_indx_np)), np.concatenate((tpq_sparse, eris_eight_np))


@frozen
Expand Down Expand Up @@ -319,8 +321,8 @@ def build_composite_bloq(
swap_rs: 'SoquetT',
swap_pqrs: 'SoquetT',
flag_1b: 'SoquetT',
alt_pqrs: 'SoquetT',
theta: 'SoquetT',
alt_pqrs: NDArray[Soquet], # type: ignore[type-var]
theta: NDArray[Soquet], # type: ignore[type-var]
keep: 'SoquetT',
less_than: 'SoquetT',
alt_flag_1b: 'SoquetT',
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/factoring/mod_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def build_composite_bloq(
sign = bb.add(XGate(), q=sign)

# Free the ancilla qubits.
junk_bit = bb.free(junk_bit)
sign = bb.free(sign)
bb.free(junk_bit)
bb.free(sign)

# Return the output registers.
return {'x': x, 'y': y}
Expand Down
7 changes: 5 additions & 2 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'SoquetT', x: 'SoquetT'
) -> Dict[str, 'SoquetT']:
k = self.k
neg_k_inv = -pow(k, -1, mod=self.mod)
if isinstance(self.mod, sympy.Expr) or isinstance(k, sympy.Expr):
neg_k_inv = sympy.Mod(sympy.Pow(k, -1), self.mod)
else:
neg_k_inv = -pow(k, -1, mod=self.mod)

# We store the result of the CtrlScaleModAdd into this new register
# and then clear the original `x` register by multiplying in the inverse.
Expand Down Expand Up @@ -148,7 +151,7 @@ def signature(self) -> 'Signature':
def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
return {'x': (2 * x) % self.p}

def build_composite_bloq(self, bb: 'BloqBuilder', x: SoquetT) -> Dict[str, 'SoquetT']:
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:

# Allocate ancilla bits for sign and double.
lower_bit = bb.allocate(n=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, Tuple, TYPE_CHECKING, Union
from typing import cast, Dict, Tuple, TYPE_CHECKING, Union

import numpy as np
from attrs import field, frozen
Expand Down Expand Up @@ -111,9 +111,10 @@ def approx_cos(self) -> Union[NDArray[np.complex_], Shaped]:
if self.is_symbolic():
return Shaped((2 * self.degree + 1,))

poly = approx_exp_cos_by_jacobi_anger(-self.t * self.alpha, degree=self.degree)
poly = approx_exp_cos_by_jacobi_anger(-self.t * self.alpha, degree=cast(int, self.degree))

# TODO(#860) current scaling method does not compute true maximum, so we scale down a bit more by (1 - 2\eps)
poly = scale_down_to_qsp_polynomial(poly) * (1 - 2 * self.precision)
poly = scale_down_to_qsp_polynomial(list(poly)) * (1 - 2 * self.precision)
return poly

@cached_property
Expand Down Expand Up @@ -168,7 +169,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
bb.free(soq)
else:
for soq_element in soq:
bb.free(soq)
bb.free(cast(Soquet, soq))

return soqs

Expand Down
10 changes: 5 additions & 5 deletions qualtran/bloqs/mcmt/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import itertools
from functools import cached_property
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, Union

import attrs
import cirq
Expand Down Expand Up @@ -81,8 +81,8 @@ class And(GateWithRegisters):
[Verifying Measurement Based Uncomputation](https://algassert.com/post/1903). Gidney, C. 2019.
"""

cv1: int = 1
cv2: int = 1
cv1: Union[int, sympy.Expr] = 1
cv2: Union[int, sympy.Expr] = 1
uncompute: bool = False

@cached_property
Expand Down Expand Up @@ -229,7 +229,7 @@ def _and_bloq() -> And:
)


def _to_tuple(x: Iterable[int]) -> Sequence[int]:
def _to_tuple(x: Iterable[Union[int, sympy.Expr]]) -> Sequence[Union[int, sympy.Expr]]:
return tuple(x)


Expand All @@ -247,7 +247,7 @@ class MultiAnd(Bloq):
target [right]: The output bit.
"""

cvs: Tuple[int, ...] = field(converter=_to_tuple)
cvs: Tuple[Union[int, sympy.Expr], ...] = field(converter=_to_tuple)

@cvs.validator
def _validate_cvs(self, field, val):
Expand Down
Loading

0 comments on commit 44a86a8

Please sign in to comment.