diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dcb15d3b..7acc744f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ Changelog - Compile to XY gates as well as CZ gates on dummy QVMs (@ecpeterson, gh-1151). - `QAM.write_memory` now accepts either a `Sequence` of values or a single value (@tommy-moffat, gh-1114). +- Added type hints for all remaining top-level files (@karalekas, gh-1150). ### Bugfixes diff --git a/Makefile b/Makefile index a2ac2af9d..2c72115d4 100644 --- a/Makefile +++ b/Makefile @@ -16,9 +16,11 @@ check-format: # The dream is to one day run mypy on the whole tree. For now, limit checks to known-good files. .PHONY: check-types check-types: - mypy pyquil/gates.py pyquil/gate_matrices.py pyquil/noise.py pyquil/numpy_simulator.py \ + mypy pyquil/gate_matrices.py pyquil/gates.py pyquil/noise.py pyquil/numpy_simulator.py \ + pyquil/operator_estimation.py pyquil/parser.py pyquil/paulis.py pyquil/pyqvm.py \ pyquil/quil.py pyquil/quilatom.py pyquil/quilbase.py pyquil/reference_simulator.py \ - pyquil/unitary_tools.py pyquil/latex pyquil/simulation pyquil/experiment pyquil/device + pyquil/unitary_tools.py pyquil/version.py pyquil/wavefunction.py \ + pyquil/device pyquil/experiment pyquil/latex pyquil/simulation .PHONY: check-style check-style: diff --git a/mypy.ini b/mypy.ini index 0a2938f7f..5a14c48a1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -19,14 +19,22 @@ warn_unused_ignores = True warn_return_any = True no_implicit_reexport = True -# Ignore errors in generated parser files -[mypy-pyquil._parser.gen3.*] +# Ignore errors in all parser-related files +[mypy-pyquil._parser.*] ignore_errors = True # Ignore errors in vendored third-party libraries [mypy-pyquil.external.*] ignore_errors = True -# Ignore errors in test files +# Ignore errors in all test files [mypy-*/tests/*] ignore_errors = True + +# Ignore errors in the conftest.py file +[mypy-conftest] +ignore_errors = True + +# Ignore errors in the pyquil/magic.py file +[mypy-pyquil.magic] +ignore_errors = True diff --git a/pyquil/operator_estimation.py b/pyquil/operator_estimation.py index cb2bd6f0d..c087f1b7c 100644 --- a/pyquil/operator_estimation.py +++ b/pyquil/operator_estimation.py @@ -1,26 +1,36 @@ import logging import warnings from math import pi -from typing import Callable, Dict, List, Union, Tuple, Optional +from typing import Callable, Generator, List, Mapping, Union, Tuple, Optional, cast import numpy as np -from pyquil import Program from pyquil.api import QuantumComputer +from pyquil.quil import Program +from pyquil.quilatom import QubitDesignator # import the full public API of the pyquil experiment module -from pyquil.experiment import ( +from pyquil.experiment._group import ( + _max_weight_state, + _max_weight_operator, + construct_tpb_graph, + group_settings as group_experiments, + group_settings_clique_removal as group_experiments_clique_removal, + group_settings_greedy as group_experiments_greedy, +) +from pyquil.experiment._main import ( + OperatorEncoder, + TomographyExperiment, +) +from pyquil.experiment._result import ExperimentResult, ratio_variance +from pyquil.experiment._setting import ( _OneQState, _pauli_to_product_state, - ExperimentResult, ExperimentSetting, - OperatorEncoder, SIC0, SIC1, SIC2, SIC3, - SymmetrizationLevel, - TomographyExperiment, TensorProductState, minusX, minusY, @@ -28,26 +38,16 @@ plusX, plusY, plusZ, - read_json, - to_json, zeros_state, ) -from pyquil.experiment._group import ( - _max_weight_state, - _max_weight_operator, - construct_tpb_graph, - group_settings as group_experiments, - group_settings_clique_removal as group_experiments_clique_removal, - group_settings_greedy as group_experiments_greedy, -) -from pyquil.experiment._result import ratio_variance +from pyquil.experiment._symmetrization import SymmetrizationLevel from pyquil.gates import RESET, RX, RY, RZ, X from pyquil.paulis import is_identity log = logging.getLogger(__name__) -def _one_q_sic_prep(index, qubit): +def _one_q_sic_prep(index: int, qubit: QubitDesignator) -> Program: """Prepare the index-th SIC basis state.""" if index == 0: return Program() @@ -67,7 +67,7 @@ def _one_q_sic_prep(index, qubit): raise ValueError(f"Bad SIC index: {index}") -def _one_q_pauli_prep(label, index, qubit): +def _one_q_pauli_prep(label: str, index: int, qubit: QubitDesignator) -> Program: """Prepare the index-th eigenstate of the pauli operator given by label.""" if index not in [0, 1]: raise ValueError(f"Bad Pauli index: {index}") @@ -93,7 +93,7 @@ def _one_q_pauli_prep(label, index, qubit): raise ValueError(f"Bad Pauli label: {label}") -def _one_q_state_prep(oneq_state: _OneQState): +def _one_q_state_prep(oneq_state: _OneQState) -> Program: """Prepare a one qubit state. Either SIC[0-3], X[0-1], Y[0-1], or Z[0-1]. @@ -107,7 +107,7 @@ def _one_q_state_prep(oneq_state: _OneQState): raise ValueError(f"Bad state label: {label}") -def _local_pauli_eig_meas(op, idx): +def _local_pauli_eig_meas(op: str, idx: QubitDesignator) -> Program: """ Generate gate sequence to measure in the eigenbasis of a Pauli operator, assuming we are only able to measure in the Z eigenbasis. (Note: The unitary operations of this @@ -175,11 +175,12 @@ def _generate_experiment_programs( "so that groups of parallel settings have compatible observables." ) for qubit, op_str in max_weight_out_op: + assert isinstance(qubit, int) total_prog += _local_pauli_eig_meas(op_str, qubit) programs.append(total_prog) - meas_qubits.append(max_weight_out_op.get_qubits()) + meas_qubits.append(cast(List[int], max_weight_out_op.get_qubits())) return programs, meas_qubits @@ -192,7 +193,7 @@ def measure_observables( symmetrize_readout: Optional[Union[int, str]] = "None", calibrate_readout: Optional[str] = "plus-eig", readout_symmetrize: Optional[str] = None, -): +) -> Generator[ExperimentResult, None, None]: """ Measure all the observables in a TomographyExperiment. @@ -305,7 +306,8 @@ def measure_observables( # either the first column, second column, or both and multiplying along the row. for setting in settings: # Get the term's coefficient so we can multiply it in later. - coeff = complex(setting.out_operator.coefficient) + coeff = setting.out_operator.coefficient + assert isinstance(coeff, complex) if not np.isclose(coeff.imag, 0): raise ValueError(f"{setting}'s out_operator has a complex coefficient.") coeff = coeff.real @@ -328,7 +330,7 @@ def measure_observables( # Obtain calibration program calibr_prog = _calibration_program(qc, tomo_experiment, setting) calibr_qubs = setting.out_operator.get_qubits() - calibr_qub_dict = {q: idx for idx, q in enumerate(calibr_qubs)} + calibr_qub_dict = {cast(int, q): idx for idx, q in enumerate(calibr_qubs)} # Perform symmetrization on the calibration program calibr_results = qc.run_symmetrized_readout( @@ -389,11 +391,11 @@ def _ops_bool_to_prog(ops_bool: Tuple[bool], qubits: List[int]) -> Program: def _stats_from_measurements( bs_results: np.ndarray, - qubit_index_map: Dict, + qubit_index_map: Mapping[int, int], setting: ExperimentSetting, n_shots: int, coeff: float = 1.0, -) -> Tuple[float, float]: +) -> Tuple[np.ndarray, np.ndarray]: """ :param bs_results: results from running `qc.run` :param qubit_index_map: dict mapping qubit to classical register index @@ -403,7 +405,7 @@ def _stats_from_measurements( :return: tuple specifying (mean, variance) """ # Identify classical register indices to select - idxs = [qubit_index_map[q] for q, _ in setting.out_operator] + idxs = [qubit_index_map[cast(int, q)] for q, _ in setting.out_operator] # Pick columns corresponding to qubits with a non-identity out_operation obs_strings = bs_results[:, idxs] # Transform bits to eigenvalues; ie (+1, -1) @@ -444,9 +446,11 @@ def _calibration_program( calibr_prog += kraus_instructions # Prepare the +1 eigenstate for the out operator for q, op in setting.out_operator.operations_as_set(): + assert isinstance(q, int) calibr_prog += _one_q_pauli_prep(label=op, index=0, qubit=q) # Measure the out operator in this state for q, op in setting.out_operator.operations_as_set(): + assert isinstance(q, int) calibr_prog += _local_pauli_eig_meas(op, q) return calibr_prog diff --git a/pyquil/parser.py b/pyquil/parser.py index 139250b10..8dda0791f 100644 --- a/pyquil/parser.py +++ b/pyquil/parser.py @@ -16,12 +16,14 @@ """ Module for parsing Quil programs from text into PyQuil objects """ -from pyquil.quil import Program +from typing import List from pyquil._parser.PyQuilListener import run_parser +from pyquil.quil import Program +from pyquil.quilbase import AbstractInstruction -def parse_program(quil): +def parse_program(quil: str) -> Program: """ Parse a raw Quil program and return a PyQuil program. @@ -31,7 +33,7 @@ def parse_program(quil): return Program(parse(quil)) -def parse(quil): +def parse(quil: str) -> List[AbstractInstruction]: """ Parse a raw Quil program and return a corresponding list of PyQuil objects. diff --git a/pyquil/paulis.py b/pyquil/paulis.py index 242414715..e2c868859 100644 --- a/pyquil/paulis.py +++ b/pyquil/paulis.py @@ -34,9 +34,15 @@ Sequence, Tuple, Union, + cast, ) -from pyquil.quilatom import QubitPlaceholder, FormalArgument, Expression, ExpressionDesignator +from pyquil.quilatom import ( + QubitPlaceholder, + FormalArgument, + Expression, + ExpressionDesignator, +) from .quil import Program from .gates import H, RZ, RX, CNOT, X, PHASE, QUANTUM_GATES @@ -113,7 +119,7 @@ def __init__(self, *args: object, **kwargs: object): """ -def _valid_qubit(index: Union[PauliTargetDesignator, QubitPlaceholder]) -> bool: +def _valid_qubit(index: Optional[Union[PauliTargetDesignator, QubitPlaceholder]]) -> bool: return ( (isinstance(index, integer_types) and index >= 0) or isinstance(index, QubitPlaceholder) @@ -126,7 +132,10 @@ class PauliTerm(object): """ def __init__( - self, op: str, index: PauliTargetDesignator, coefficient: ExpressionDesignator = 1.0 + self, + op: str, + index: Optional[PauliTargetDesignator], + coefficient: ExpressionDesignator = 1.0, ): """ Create a new Pauli Term with a Pauli operator at a particular index and a leading coefficient. @@ -142,12 +151,11 @@ def __init__( if op != "I": if not _valid_qubit(index): raise ValueError(f"{index} is not a valid qubit") + assert index is not None self._ops[index] = op - self.coefficient: Union[complex, Expression] - if isinstance(coefficient, Number): - self.coefficient = complex(coefficient) + self.coefficient: Union[complex, Expression] = complex(coefficient) else: self.coefficient = coefficient @@ -286,8 +294,7 @@ def __mul__(self, term: Union[PauliDesignator, ExpressionDesignator]) -> PauliDe new_term = new_term._multiply_factor(op, index) return term_with_coeff(new_term, new_term.coefficient * new_coeff) - else: # is a Number - return term_with_coeff(self, self.coefficient * term) + return term_with_coeff(self, self.coefficient * term) def __rmul__(self, other: ExpressionDesignator) -> "PauliTerm": """Multiplies this PauliTerm with another object, probably a number. @@ -295,8 +302,9 @@ def __rmul__(self, other: ExpressionDesignator) -> "PauliTerm": :param other: A number or PauliTerm to multiply by :returns: A new PauliTerm """ - assert isinstance(other, Number) - return self * other + p = self * other + assert isinstance(p, PauliTerm) + return p def __pow__(self, power: int) -> "PauliTerm": """Raises this PauliTerm to power. @@ -313,7 +321,7 @@ def __pow__(self, power: int) -> "PauliTerm": result = ID() for _ in range(power): - result *= self + result = cast(PauliTerm, result * self) return result def __add__(self, other: Union[PauliDesignator, ExpressionDesignator]) -> "PauliSum": @@ -330,24 +338,23 @@ def __add__(self, other: Union[PauliDesignator, ExpressionDesignator]) -> "Pauli else: # is a Number return self + PauliTerm("I", 0, other) - def __radd__(self, other: ExpressionDesignator) -> "PauliTerm": + def __radd__(self, other: ExpressionDesignator) -> "PauliSum": """Adds this PauliTerm with a Number. :param other: A Number - :returns: A new PauliTerm + :returns: A new PauliSum """ - assert isinstance(other, Number) return PauliTerm("I", 0, other) + self - def __sub__(self, other: Union["PauliTerm", Number]) -> "PauliSum": + def __sub__(self, other: Union["PauliTerm", ExpressionDesignator]) -> "PauliSum": """Subtracts a PauliTerm from this one. - :param other: A PauliTerm object or a Number + :param other: A PauliTerm object, a number, or an Expression :returns: A PauliSum object representing the difference of this PauliTerm and term """ return self + -1.0 * other - def __rsub__(self, other: Union["PauliTerm", Number]) -> "PauliSum": + def __rsub__(self, other: Union["PauliTerm", ExpressionDesignator]) -> "PauliSum": """Subtracts this PauliTerm from a Number or PauliTerm. :param other: A PauliTerm object or a Number @@ -430,9 +437,8 @@ def from_compact_str(cls, str_pauli_term: str) -> "PauliTerm": # parse the coefficient into either a float or complex str_coef = str_coef.replace(" ", "") - coef: Union[float, complex] try: - coef = float(str_coef) + coef: Union[float, complex] = float(str_coef) except ValueError: try: coef = complex(str_coef) @@ -441,6 +447,7 @@ def from_compact_str(cls, str_pauli_term: str) -> "PauliTerm": op = sI() * coef if str_op == "I": + assert isinstance(op, PauliTerm) return op # parse the operator @@ -453,6 +460,7 @@ def from_compact_str(cls, str_pauli_term: str) -> "PauliTerm": for factor in re.finditer(r"([XYZ])(\d+)", str_op): op *= cls(factor.group(1), int(factor.group(2))) + assert isinstance(op, PauliTerm) return op def pauli_string(self, qubits: Optional[Iterable[int]] = None) -> str: @@ -483,7 +491,7 @@ def pauli_string(self, qubits: Optional[Iterable[int]] = None) -> str: "Please provide a list of qubits when using PauliTerm.pauli_string", DeprecationWarning, ) - qubits = self.get_qubits() + qubits = cast(List[int], self.get_qubits()) assert qubits is not None return "".join(self[q] for q in qubits) @@ -626,16 +634,19 @@ def __mul__(self, other: Union[PauliDesignator, ExpressionDesignator]) -> "Pauli :param other: a PauliSum, PauliTerm or Number object :return: A new PauliSum object given by the multiplication. """ - if not isinstance(other, (Number, PauliTerm, PauliSum)): + if not isinstance(other, (Expression, int, complex, float, PauliTerm, PauliSum)): raise ValueError( "Cannot multiply PauliSum by term that is not a Number, PauliTerm, or PauliSum" ) - elif isinstance(other, PauliSum): - other_terms = other.terms + + other_terms: List[Union[PauliTerm, ExpressionDesignator]] = [] + if isinstance(other, PauliSum): + other_terms += other.terms else: - other_terms = [other] + other_terms += [other] + new_terms = [lterm * rterm for lterm, rterm in product(self.terms, other_terms)] - new_sum = PauliSum(new_terms) + new_sum = PauliSum(cast(List[PauliTerm], new_terms)) return new_sum.simplify() def __rmul__(self, other: ExpressionDesignator) -> "PauliSum": @@ -685,11 +696,13 @@ def __add__(self, other: Union[PauliDesignator, ExpressionDesignator]) -> "Pauli :return: A new PauliSum object given by the addition. """ if isinstance(other, PauliTerm): - other = PauliSum([other]) - elif isinstance(other, Number): - other = PauliSum([other * ID()]) + other_sum = PauliSum([other]) + elif isinstance(other, (Expression, int, complex, float)): + other_sum = PauliSum([other * ID()]) + else: + other_sum = other new_terms = [term.copy() for term in self.terms] - new_terms.extend(other.terms) + new_terms.extend(other_sum.terms) new_sum = PauliSum(new_terms) return new_sum.simplify() @@ -730,7 +743,10 @@ def get_qubits(self) -> List[PauliTargetDesignator]: :returns: A list of all the qubits in the sum of terms. """ - return list(set().union(*[term.get_qubits() for term in self.terms])) + all_qubits = [] + for term in self.terms: + all_qubits.extend(term.get_qubits()) + return list(set(all_qubits)) def simplify(self) -> "PauliSum": """ @@ -899,6 +915,8 @@ def exponential_map(term: PauliTerm) -> Callable[[float], Program]: :param term: A pauli term to exponentiate :returns: A function that takes an angle parameter and returns a program. """ + assert isinstance(term.coefficient, (float, complex)) + if not np.isclose(np.imag(term.coefficient), 0.0): raise TypeError("PauliTerm coefficient must be real") @@ -964,14 +982,13 @@ def reverse_hack(p: Program) -> Program: highest_target_index = None for index, op in pauli_term: + assert isinstance(index, (int, QubitPlaceholder)) if "X" == op: change_to_z_basis.inst(H(index)) change_to_original_basis.inst(H(index)) - elif "Y" == op: change_to_z_basis.inst(RX(np.pi / 2.0, index)) change_to_original_basis.inst(RX(-np.pi / 2.0, index)) - elif "I" == op: continue @@ -984,6 +1001,7 @@ def reverse_hack(p: Program) -> Program: # building rotation circuit quil_prog += change_to_z_basis quil_prog += cnot_seq + assert isinstance(pauli_term.coefficient, (float, complex)) and highest_target_index is not None quil_prog.inst(RZ(2.0 * pauli_term.coefficient * param, highest_target_index)) quil_prog += reverse_hack(cnot_seq) quil_prog += change_to_original_basis @@ -1053,8 +1071,10 @@ def is_zero(pauli_object: PauliDesignator) -> bool: :returns: True if PauliTerm is zero, False otherwise """ if isinstance(pauli_object, PauliTerm): - return np.isclose(pauli_object.coefficient, 0) + assert isinstance(pauli_object.coefficient, (float, complex)) + return bool(np.isclose(pauli_object.coefficient, 0)) elif isinstance(pauli_object, PauliSum): + assert isinstance(pauli_object.terms[0].coefficient, (float, complex)) return len(pauli_object.terms) == 1 and np.isclose(pauli_object.terms[0].coefficient, 0) else: raise TypeError("is_zero only checks PauliTerms and PauliSum objects!") diff --git a/pyquil/pyqvm.py b/pyquil/pyqvm.py index a6a25701b..66f709586 100644 --- a/pyquil/pyqvm.py +++ b/pyquil/pyqvm.py @@ -16,7 +16,7 @@ import logging import warnings from abc import ABC, abstractmethod -from typing import Dict, List, Sequence, Type, Union +from typing import Dict, List, Optional, Sequence, Type, Union import numpy as np from numpy.random.mtrand import RandomState @@ -26,7 +26,7 @@ from pyquil.api._compiler import _extract_program_from_pyquil_executable_response from pyquil.paulis import PauliTerm, PauliSum from pyquil.quil import Program -from pyquil.quilatom import Label, MemoryReference +from pyquil.quilatom import Label, LabelPlaceholder, MemoryReference from pyquil.quilbase import ( Gate, Measurement, @@ -137,7 +137,7 @@ def reset(self) -> "AbstractQuantumSimulator": """ @abstractmethod - def sample_bitstrings(self, n_samples) -> np.ndarray: + def sample_bitstrings(self, n_samples: int) -> np.ndarray: """ Sample bitstrings from the current state. @@ -164,10 +164,10 @@ def do_post_gate_noise( class PyQVM(QAM): def __init__( self, - n_qubits, - quantum_simulator_type: Type[AbstractQuantumSimulator] = None, - seed=None, - post_gate_noise_probabilities: Dict[str, float] = None, + n_qubits: int, + quantum_simulator_type: Optional[Type[AbstractQuantumSimulator]] = None, + seed: Optional[int] = None, + post_gate_noise_probabilities: Optional[Dict[str, float]] = None, ): """ PyQuil's built-in Quil virtual machine. @@ -199,26 +199,26 @@ def __init__( quantum_simulator_type = ReferenceDensitySimulator self.n_qubits = n_qubits - self.ram = {} + self.ram: Dict[str, List[Union[int, float]]] = {} if post_gate_noise_probabilities is None: post_gate_noise_probabilities = {} self.post_gate_noise_probabilities = post_gate_noise_probabilities - self.program = None # type: Program - self.program_counter = None # type: int - self.defined_gates = dict() # type: Dict[str, np.ndarray] + self.program: Optional[Program] = None + self.program_counter: int = 0 + self.defined_gates: Dict[str, np.ndarray] = dict() # private implementation details - self._qubit_to_ram = None # type: Dict[int, int] - self._ro_size = None # type :int - self._memory_results = None # type: Dict[str, np.ndarray] + self._qubit_to_ram: Optional[Dict[int, int]] = None + self._ro_size: Optional[int] = None + self._memory_results: Optional[Dict[str, np.ndarray]] = None self.rs = np.random.RandomState(seed=seed) self.wf_simulator = quantum_simulator_type(n_qubits=n_qubits, rs=self.rs) self._last_measure_program_loc = None - def load(self, executable): + def load(self, executable: Union[Program, PyQuilExecutableResponse]) -> "PyQVM": if isinstance(executable, PyQuilExecutableResponse): program = _extract_program_from_pyquil_executable_response(executable) else: @@ -240,8 +240,9 @@ def load(self, executable): self.status = "loaded" return self - def _extract_defined_gates(self): + def _extract_defined_gates(self) -> None: self.defined_gates = dict() + assert self.program is not None for dg in self.program.defined_gates: if dg.parameters is not None and len(dg.parameters) > 0: raise NotImplementedError("PyQVM does not support parameterized DEFGATEs") @@ -251,16 +252,17 @@ def _extract_defined_gates(self): ) self.defined_gates[dg.name] = dg.matrix - def write_memory(self, *, region_name: str, offset: int = 0, value=None): + def write_memory(self, *, region_name: str, offset: int = 0, value: int = 0) -> "PyQVM": assert self.status in ["loaded", "done"] assert region_name != "ro" self.ram[region_name][offset] = value return self - def run(self): + def run(self) -> "PyQVM": self.status = "running" self._memory_results = {} + assert self.program is not None for _ in range(self.program.num_shots): self.wf_simulator.reset() self._execute_program() @@ -273,15 +275,16 @@ def run(self): return self - def wait(self): + def wait(self) -> "PyQVM": assert self.status == "running" self.status = "done" return self - def read_memory(self, *, region_name: str): + def read_memory(self, *, region_name: str) -> np.ndarray: + assert self._memory_results is not None return np.asarray(self._memory_results[region_name]) - def find_label(self, label: Label): + def find_label(self, label: Union[Label, LabelPlaceholder]) -> int: """ Helper function that iterates over the program and looks for a JumpTarget that has a Label matching the input label. @@ -289,6 +292,7 @@ def find_label(self, label: Label): :param label: Label object to search for in program :return: Program index where ``label`` is found """ + assert self.program is not None for index, action in enumerate(self.program): if isinstance(action, JumpTarget): if label == action.label: @@ -296,7 +300,7 @@ def find_label(self, label: Label): raise RuntimeError("Improper program - Jump Target not found in the input program!") - def transition(self): + def transition(self) -> bool: """ Implements a QAM-like transition. @@ -306,6 +310,7 @@ def transition(self): :return: whether the QAM should halt after this transition. """ + assert self.program is not None instruction = self.program[self.program_counter] if isinstance(instruction, Gate): @@ -326,8 +331,9 @@ def transition(self): elif isinstance(instruction, Measurement): measured_val = self.wf_simulator.do_measurement(qubit=instruction.qubit.index) - x = instruction.classical_reg # type: MemoryReference - self.ram[x.name][x.offset] = measured_val + meas_reg: Optional[MemoryReference] = instruction.classical_reg + assert meas_reg is not None + self.ram[meas_reg.name][meas_reg.offset] = measured_val self.program_counter += 1 elif isinstance(instruction, Declare): @@ -353,8 +359,9 @@ def transition(self): elif isinstance(instruction, JumpConditional): # JumpConditional; check classical reg - x = instruction.condition # type: MemoryReference - cond = self.ram[x.name][x.offset] + jump_reg: Optional[MemoryReference] = instruction.condition + assert jump_reg is not None + cond = self.ram[jump_reg.name][jump_reg.offset] if not isinstance(cond, (bool, np.bool, np.int8)): raise ValueError( "{} requires a data type of BIT; not {}".format(instruction.op, type(cond)) @@ -376,7 +383,7 @@ def transition(self): elif isinstance(instruction, UnaryClassicalInstruction): # UnaryClassicalInstruction; set classical reg - target = instruction.target # type:MemoryReference + target = instruction.target old = self.ram[target.name][target.offset] if isinstance(instruction, ClassicalNeg): if not isinstance(old, (int, float, np.int, np.float)): @@ -394,19 +401,22 @@ def transition(self): self.program_counter += 1 elif isinstance(instruction, (LogicalBinaryOp, ArithmeticBinaryOp, ClassicalMove)): - left_ind = instruction.left # type: MemoryReference + left_ind = instruction.left left_val = self.ram[left_ind.name][left_ind.offset] if isinstance(instruction.right, MemoryReference): - right_ind = instruction.right # type: MemoryReference + right_ind = instruction.right right_val = self.ram[right_ind.name][right_ind.offset] else: right_val = instruction.right if isinstance(instruction, ClassicalAnd): - new_val = left_val & right_val + assert isinstance(left_val, int) and isinstance(right_val, int) + new_val: Union[int, float] = left_val & right_val elif isinstance(instruction, ClassicalInclusiveOr): + assert isinstance(left_val, int) and isinstance(right_val, int) new_val = left_val | right_val elif isinstance(instruction, ClassicalExclusiveOr): + assert isinstance(left_val, int) and isinstance(right_val, int) new_val = left_val ^ right_val elif isinstance(instruction, ClassicalAdd): new_val = left_val + right_val @@ -424,12 +434,14 @@ def transition(self): self.program_counter += 1 elif isinstance(instruction, ClassicalExchange): - left_ind = instruction.left # type: MemoryReference - right_ind = instruction.right # type: MemoryReference - - tmp = self.ram[left_ind.name][left_ind.offset] - self.ram[left_ind.name][left_ind.offset] = self.ram[right_ind.name][right_ind.offset] - self.ram[right_ind.name][right_ind.offset] = tmp + left_ind_ex = instruction.left + right_ind_ex = instruction.right + + tmp = self.ram[left_ind_ex.name][left_ind_ex.offset] + self.ram[left_ind_ex.name][left_ind_ex.offset] = self.ram[right_ind_ex.name][ + right_ind_ex.offset + ] + self.ram[right_ind_ex.name][right_ind_ex.offset] = tmp self.program_counter += 1 elif isinstance(instruction, Reset): @@ -437,9 +449,7 @@ def transition(self): self.program_counter += 1 elif isinstance(instruction, ResetQubit): - # TODO raise NotImplementedError("Need to implement in wf simulator") - self.program_counter += 1 elif isinstance(instruction, Wait): warnings.warn("WAIT does nothing for a noiseless simulator") @@ -464,18 +474,20 @@ def transition(self): raise ValueError("Unsupported instruction type: {}".format(instruction)) # return HALTED (i.e. program_counter is end of program) + assert self.program is not None return self.program_counter == len(self.program) - def _execute_program(self): + def _execute_program(self) -> "PyQVM": self.program_counter = 0 + assert self.program is not None halted = len(self.program) == 0 while not halted: halted = self.transition() return self - def execute(self, program: Program): + def execute(self, program: Program) -> "PyQVM": """ Execute one outer loop of a program on the QVM. diff --git a/pyquil/tests/test_wavefunction.py b/pyquil/tests/test_wavefunction.py index 636ceec53..df859d2cc 100644 --- a/pyquil/tests/test_wavefunction.py +++ b/pyquil/tests/test_wavefunction.py @@ -5,7 +5,6 @@ from pyquil.wavefunction import ( get_bitstring_from_index, Wavefunction, - _round_to_next_multiple, _octet_bits, ) @@ -48,16 +47,6 @@ def test_ground_state(): assert ground.amplitudes[0] == 1.0 -def test_rounding(): - for i in range(8): - if 0 == i % 8: - assert i == _round_to_next_multiple(i, 8) - else: - assert 8 == _round_to_next_multiple(i, 8) - assert 16 == _round_to_next_multiple(i + 8, 8) - assert 24 == _round_to_next_multiple(i + 16, 8) - - def test_octet_bits(): assert [0, 0, 0, 0, 0, 0, 0, 0] == _octet_bits(0b0) assert [1, 0, 0, 0, 0, 0, 0, 0] == _octet_bits(0b1) diff --git a/pyquil/wavefunction.py b/pyquil/wavefunction.py index f15398638..b56ae5de0 100644 --- a/pyquil/wavefunction.py +++ b/pyquil/wavefunction.py @@ -16,9 +16,10 @@ """ Module containing the Wavefunction object and methods for working with wavefunctions. """ +import itertools import struct import warnings -import itertools +from typing import Dict, Iterator, List, Optional, Sequence, cast import numpy as np @@ -40,7 +41,7 @@ class Wavefunction(object): `. """ - def __init__(self, amplitude_vector): + def __init__(self, amplitude_vector: np.ndarray): """ Initializes a wavefunction @@ -58,29 +59,27 @@ def __init__(self, amplitude_vector): ) @staticmethod - def ground(qubit_num): + def ground(qubit_num: int) -> "Wavefunction": warnings.warn("ground() has been deprecated in favor of zeros()", stacklevel=2) return Wavefunction.zeros(qubit_num) @staticmethod - def zeros(qubit_num): + def zeros(qubit_num: int) -> "Wavefunction": """ Constructs the groundstate wavefunction for a given number of qubits. - :param int qubit_num: + :param qubit_num: :return: A Wavefunction in the ground state - :rtype: Wavefunction """ amplitude_vector = np.zeros(2 ** qubit_num) amplitude_vector[0] = 1.0 return Wavefunction(amplitude_vector) @staticmethod - def from_bit_packed_string(coef_string): + def from_bit_packed_string(coef_string: bytes) -> "Wavefunction": """ From a bit packed string, unpacks to get the wavefunction - :param bytes coef_string: - :return: + :param coef_string: """ num_octets = len(coef_string) @@ -95,26 +94,26 @@ def from_bit_packed_string(coef_string): return Wavefunction(wf) - def __len__(self): + def __len__(self) -> int: return len(self.amplitudes).bit_length() - 1 - def __iter__(self): - return self.amplitudes.__iter__() + def __iter__(self) -> Iterator[complex]: + return cast(Iterator[complex], self.amplitudes.__iter__()) - def __getitem__(self, index): - return self.amplitudes[index] + def __getitem__(self, index: int) -> complex: + return cast(complex, self.amplitudes[index]) - def __setitem__(self, key, value): + def __setitem__(self, key: int, value: complex) -> None: self.amplitudes[key] = value - def __str__(self): + def __str__(self) -> str: return self.pretty_print(decimal_digits=10) - def probabilities(self): + def probabilities(self) -> np.ndarray: """Returns an array of probabilities in lexicographical order""" return np.abs(self.amplitudes) ** 2 - def get_outcome_probs(self): + def get_outcome_probs(self) -> Dict[str, float]: """ Parses a wavefunction (array of complex amplitudes) and returns a dictionary of outcomes and associated probabilities. @@ -129,14 +128,15 @@ def get_outcome_probs(self): outcome_dict[outcome] = abs(amplitude) ** 2 return outcome_dict - def pretty_print_probabilities(self, decimal_digits=2): + def pretty_print_probabilities(self, decimal_digits: int = 2) -> Dict[str, float]: """ + TODO: This doesn't seem like it is named correctly... + Prints outcome probabilities, ignoring all outcomes with approximately zero probabilities (up to a certain number of decimal digits) and rounding the probabilities to decimal_digits. :param int decimal_digits: The number of digits to truncate to. :return: A dict with outcomes as keys and probabilities as values. - :rtype: dict """ outcome_dict = {} qubit_num = len(self) @@ -147,15 +147,14 @@ def pretty_print_probabilities(self, decimal_digits=2): outcome_dict[outcome] = prob return outcome_dict - def pretty_print(self, decimal_digits=2): + def pretty_print(self, decimal_digits: int = 2) -> str: """ Returns a string repr of the wavefunction, ignoring all outcomes with approximately zero amplitude (up to a certain number of decimal digits) and rounding the amplitudes to decimal_digits. :param int decimal_digits: The number of digits to truncate to. - :return: A dict with outcomes as keys and complex amplitudes as values. - :rtype: str + :return: A string representation of the wavefunction. """ outcome_dict = {} qubit_num = len(self) @@ -172,12 +171,13 @@ def pretty_print(self, decimal_digits=2): pp_string = pp_string[:-3] # remove the dangling + if it is there return pp_string - def plot(self, qubit_subset=None): + def plot(self, qubit_subset: Optional[Sequence[int]] = None) -> None: """ + TODO: calling this will error because of matplotlib + Plots a bar chart with bitstring on the x axis and probability on the y axis. - :param list qubit_subset: Optional parameter used for plotting a subset of the Hilbert - space. + :param qubit_subset: Optional parameter used for plotting a subset of the Hilbert space. """ import matplotlib.pyplot as plt @@ -197,7 +197,7 @@ def plot(self, qubit_subset=None): plt.xticks(range(len(prob_dict)), prob_dict.keys()) plt.show() - def sample_bitstrings(self, n_samples): + def sample_bitstrings(self, n_samples: int) -> np.ndarray: """ Sample bitstrings from the distribution defined by the wavefunction. @@ -210,7 +210,7 @@ def sample_bitstrings(self, n_samples): return bitstrings -def get_bitstring_from_index(index, qubit_num): +def get_bitstring_from_index(index: int, qubit_num: int) -> str: """ Returns the bitstring in lexical order that corresponds to the given index in 0 to 2^(qubit_num) :param int index: @@ -223,24 +223,12 @@ def get_bitstring_from_index(index, qubit_num): return bin(index)[2:].rjust(qubit_num, "0") -def _round_to_next_multiple(n, m): - """ - Round up the the next multiple. - - :param n: The number to round up. - :param m: The multiple. - :return: The rounded number - """ - return n if n % m == 0 else n + m - n % m - - -def _octet_bits(o): +def _octet_bits(o: int) -> List[int]: """ Get the bits of an octet. :param o: The octets. :return: The bits as a list in LSB-to-MSB order. - :rtype: list """ if not isinstance(o, int): raise TypeError("o should be an int")