Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy issues part 3 #898

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions dev_tools/qualtran_dev_tools/jupyter_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class NotebookSpecV2:

@directory.default
def _default_directory(self) -> str:
return str(Path(self.module.__file__).parent)
path = self.module.__file__
assert path is not None
return str(Path(path).parent)

@property
def path_stem(self):
Expand Down Expand Up @@ -156,7 +158,7 @@ class _PyCell(_Cell):
cell_id: str


def get_bloq_doc_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_MarkdownCell]:
def get_bloq_doc_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_Cell]:
"""Cells introducing the `bloq_cls`"""

md_doc: str = '\n'.join(get_markdown_docstring_lines(bloqdoc.bloq_cls))
Expand Down Expand Up @@ -230,7 +232,7 @@ def get_call_graph_cells(bloqdoc: BloqDocSpec, cid_prefix: str) -> List[_Cell]:


def get_cells(bloqdoc: BloqDocSpec) -> List[_Cell]:
cells = []
cells: List[_Cell] = []
cid_prefix = f'{bloqdoc.bloq_cls.__name__}'
cells += get_bloq_doc_cells(bloqdoc, cid_prefix)
cells += get_example_instances_cells(bloqdoc, cid_prefix)
Expand Down
2 changes: 1 addition & 1 deletion qualtran/_infra/bloq_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bloq_example(_func: Callable[[], _BloqType], **kwargs: Any) -> BloqExample[_

@typing.overload
def bloq_example(
_func: None, *, generalizer: _GeneralizerType = lambda x: x
_func: None = None, *, generalizer: _GeneralizerType = lambda x: x
) -> Callable[[Callable[[], _BloqType]], BloqExample[_BloqType]]:
...

Expand Down
9 changes: 5 additions & 4 deletions qualtran/bloqs/chemistry/df/double_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import CSwap, Hadamard, Toffoli
Expand Down Expand Up @@ -156,17 +157,17 @@ def short_name(self) -> str:
def build_composite_bloq(
self,
bb: 'BloqBuilder',
succ_l: SoquetT,
succ_l: Soquet,
l_ne_zero: SoquetT,
succ_p: SoquetT,
succ_p: Soquet,
p: SoquetT,
rot_aa: SoquetT,
spin: SoquetT,
xi: SoquetT,
offset: SoquetT,
rot: SoquetT,
rotations: SoquetT,
sys: SoquetT,
sys: NDArray[Soquet], # type: ignore[type-var]
) -> Dict[str, 'SoquetT']:
# 1st half
in_prep = InnerPrepareDoubleFactorization(
Expand Down Expand Up @@ -374,7 +375,7 @@ def signature(self) -> Signature:
def build_composite_bloq(
self,
bb: 'BloqBuilder',
ctrl: SoquetT,
ctrl: NDArray[Soquet], # type: ignore[type-var]
l: SoquetT,
p: SoquetT,
spin: SoquetT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
from attrs import evolve, field, frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand Down Expand Up @@ -107,7 +108,7 @@ def short_name(self) -> str:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if self.is_adjoint:
# inverting inequality tests at zero Toffoli.
return {}
return set()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops

else:
return {(Toffoli(), 6 * self.num_bits_t + 2)}

Expand Down Expand Up @@ -147,6 +148,7 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
(c_idx,) = soq.idx
filled = bool(self.cvs[c_idx])
return Circle(filled)
raise ValueError(f'Unknown name: {soq.reg.name}')

def build_composite_bloq(
self, bb: BloqBuilder, ctrl: SoquetT, sel: SoquetT, targets: SoquetT, junk: SoquetT
Expand Down Expand Up @@ -294,9 +296,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
succ_nu: SoquetT,
l: SoquetT,
Expand Down Expand Up @@ -462,7 +464,7 @@ def short_name(self) -> str:
def build_composite_bloq(
self,
bb: BloqBuilder,
ham_ctrl: SoquetT,
ham_ctrl: NDArray[Soquet], # type: ignore[type-var]
i_ne_j: SoquetT,
plus_t: SoquetT,
i: SoquetT,
Expand All @@ -478,7 +480,7 @@ def build_composite_bloq(
m: SoquetT,
l: SoquetT,
sys: SoquetT,
proj: SoquetT,
proj: NDArray[Soquet], # type: ignore[type-var]
) -> Dict[str, 'SoquetT']:
# ancilla for swaps from electronic and projectile system registers.
# we assume these are left in a clean state after SELECT operations
Expand Down Expand Up @@ -543,8 +545,10 @@ def build_composite_bloq(
j, sys, q = bb.add(
MultiplexedCSwap3D(self.num_bits_p, self.eta), sel=j, targets=sys, junk=q
)
_ = [bb.free(pi) for pi in p]
_ = [bb.free(qi) for qi in q]
for pi in p:
bb.free(pi)
for qi in q:
bb.free(qi)
ham_ctrl[:] = flag_t, flag_t_mean, flag_uv, flag_proj
bb.free(rl)
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
from attrs import frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand All @@ -28,6 +29,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import Toffoli
Expand Down Expand Up @@ -131,7 +133,7 @@ def signature(self) -> Signature:
@staticmethod
def _reshape_reg(
bb: BloqBuilder, in_reg: SoquetT, out_shape: Tuple[int, ...], bitsize: int
) -> SoquetT:
) -> NDArray[Soquet]: # type: ignore[type-var]
"""Reshape registers allocated as a big register.

Example:
Expand Down Expand Up @@ -162,6 +164,7 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
return TextBox('×(x)')
elif soq.reg.name == 'junk':
return TextBox('×(y)')
raise ValueError(f'Unknown name: {soq.reg.name}')

def short_name(self) -> str:
return 'MultiSwap'
Expand Down Expand Up @@ -279,9 +282,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
succ_nu: SoquetT,
l: SoquetT,
Expand Down Expand Up @@ -445,9 +448,9 @@ def build_composite_bloq(
r: SoquetT,
s: SoquetT,
mu: SoquetT,
nu_x: SoquetT,
nu_y: SoquetT,
nu_z: SoquetT,
nu_x: Soquet,
nu_y: Soquet,
nu_z: Soquet,
m: SoquetT,
l: SoquetT,
sys: SoquetT,
Expand Down Expand Up @@ -490,8 +493,10 @@ def build_composite_bloq(
j, sys, q = bb.add(
MultiplexedCSwap3D(self.num_bits_p, self.eta), sel=j, targets=sys, junk=q
)
_ = [bb.free(pi) for pi in p]
_ = [bb.free(qi) for qi in q]
for pi in p:
bb.free(pi)
for qi in q:
bb.free(qi)
bb.free(rl)
return {
'tuv': tuv,
Expand Down
8 changes: 5 additions & 3 deletions qualtran/bloqs/chemistry/sf/single_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import numpy as np
from attrs import evolve, frozen
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand All @@ -37,6 +38,7 @@
QBit,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran._infra.data_types import BoundedQUInt
Expand Down Expand Up @@ -329,10 +331,10 @@ def build_composite_bloq(
self,
bb: 'BloqBuilder',
*,
ctrl: SoquetT,
ctrl: NDArray[Soquet], # type: ignore[type-var]
l: SoquetT,
pq: SoquetT,
rot_aa: SoquetT,
pq: NDArray[Soquet], # type: ignore[type-var]
rot_aa: NDArray[Soquet], # type: ignore[type-var]
swap_pq: SoquetT,
spin: SoquetT,
sys: SoquetT,
Expand Down
14 changes: 8 additions & 6 deletions qualtran/bloqs/chemistry/sparse/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
QAny,
QBit,
Register,
Soquet,
SoquetT,
)
from qualtran.bloqs.arithmetic.comparison import LessThanEqual
Expand All @@ -43,6 +44,7 @@
from qualtran.linalg.lcu_util import preprocess_lcu_coefficients_for_reversible_sampling

if TYPE_CHECKING:
from qualtran import Bloq
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -102,12 +104,12 @@ def _add(p_indx: int, q_indx: int, r_indx: int, s_indx: int):
_add(q, q, q, p)
for p in range(num_spat):
_add(p, p, p, p)
eris_eight = np.array(eris_eight)
eris_eight_np = np.array(eris_eight)
pqrs_indx_np = np.array(pqrs_indx)
keep_indx = np.where(np.abs(eris_eight) > drop_element_thresh)
eris_eight = eris_eight[keep_indx]
keep_indx = np.where(np.abs(eris_eight_np) > drop_element_thresh)
eris_eight_np = eris_eight_np[keep_indx]
pqrs_indx_np = pqrs_indx_np[keep_indx[0]]
return np.concatenate((tpq_indx, pqrs_indx_np)), np.concatenate((tpq_sparse, eris_eight))
return np.concatenate((tpq_indx, pqrs_indx_np)), np.concatenate((tpq_sparse, eris_eight_np))


@frozen
Expand Down Expand Up @@ -319,8 +321,8 @@ def build_composite_bloq(
swap_rs: 'SoquetT',
swap_pqrs: 'SoquetT',
flag_1b: 'SoquetT',
alt_pqrs: 'SoquetT',
theta: 'SoquetT',
alt_pqrs: NDArray[Soquet], # type: ignore[type-var]
theta: NDArray[Soquet], # type: ignore[type-var]
keep: 'SoquetT',
less_than: 'SoquetT',
alt_flag_1b: 'SoquetT',
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/factoring/mod_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def build_composite_bloq(
sign = bb.add(XGate(), q=sign)

# Free the ancilla qubits.
junk_bit = bb.free(junk_bit)
sign = bb.free(sign)
bb.free(junk_bit)
bb.free(sign)

# Return the output registers.
return {'x': x, 'y': y}
Expand Down
7 changes: 5 additions & 2 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'SoquetT', x: 'SoquetT'
) -> Dict[str, 'SoquetT']:
k = self.k
neg_k_inv = -pow(k, -1, mod=self.mod)
if isinstance(self.mod, sympy.Expr) or isinstance(k, sympy.Expr):
neg_k_inv = sympy.Mod(sympy.Pow(k, -1), self.mod)
else:
neg_k_inv = -pow(k, -1, mod=self.mod)

# We store the result of the CtrlScaleModAdd into this new register
# and then clear the original `x` register by multiplying in the inverse.
Expand Down Expand Up @@ -148,7 +151,7 @@ def signature(self) -> 'Signature':
def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
return {'x': (2 * x) % self.p}

def build_composite_bloq(self, bb: 'BloqBuilder', x: SoquetT) -> Dict[str, 'SoquetT']:
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:

# Allocate ancilla bits for sign and double.
lower_bit = bb.allocate(n=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, Tuple, TYPE_CHECKING, Union
from typing import cast, Dict, Tuple, TYPE_CHECKING, Union

import numpy as np
from attrs import field, frozen
Expand Down Expand Up @@ -111,9 +111,10 @@ def approx_cos(self) -> Union[NDArray[np.complex_], Shaped]:
if self.is_symbolic():
return Shaped((2 * self.degree + 1,))

poly = approx_exp_cos_by_jacobi_anger(-self.t * self.alpha, degree=self.degree)
poly = approx_exp_cos_by_jacobi_anger(-self.t * self.alpha, degree=cast(int, self.degree))

# TODO(#860) current scaling method does not compute true maximum, so we scale down a bit more by (1 - 2\eps)
poly = scale_down_to_qsp_polynomial(poly) * (1 - 2 * self.precision)
poly = scale_down_to_qsp_polynomial(list(poly)) * (1 - 2 * self.precision)
return poly

@cached_property
Expand Down Expand Up @@ -168,7 +169,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
bb.free(soq)
else:
for soq_element in soq:
bb.free(soq)
bb.free(cast(Soquet, soq))

return soqs

Expand Down
10 changes: 5 additions & 5 deletions qualtran/bloqs/mcmt/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import itertools
from functools import cached_property
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, Union

import attrs
import cirq
Expand Down Expand Up @@ -81,8 +81,8 @@ class And(GateWithRegisters):
[Verifying Measurement Based Uncomputation](https://algassert.com/post/1903). Gidney, C. 2019.
"""

cv1: int = 1
cv2: int = 1
cv1: Union[int, sympy.Expr] = 1
cv2: Union[int, sympy.Expr] = 1
uncompute: bool = False

@cached_property
Expand Down Expand Up @@ -229,7 +229,7 @@ def _and_bloq() -> And:
)


def _to_tuple(x: Iterable[int]) -> Sequence[int]:
def _to_tuple(x: Iterable[Union[int, sympy.Expr]]) -> Sequence[Union[int, sympy.Expr]]:
return tuple(x)


Expand All @@ -247,7 +247,7 @@ class MultiAnd(Bloq):
target [right]: The output bit.
"""

cvs: Tuple[int, ...] = field(converter=_to_tuple)
cvs: Tuple[Union[int, sympy.Expr], ...] = field(converter=_to_tuple)

@cvs.validator
def _validate_cvs(self, field, val):
Expand Down
Loading
Loading