From 6b4cfcd3fc46856535ec31478e247cd57f8275d4 Mon Sep 17 00:00:00 2001 From: ikkoham Date: Wed, 28 Oct 2020 11:26:26 +0900 Subject: [PATCH] separate StateFn to Factory and BaseClass --- .../converters/pauli_basis_change.py | 3 +- .../gradients/circuit_gradients/lin_comb.py | 15 +- .../circuit_gradients/param_shift.py | 3 +- qiskit/aqua/operators/gradients/hessian.py | 3 +- .../operators/state_fns/circuit_state_fn.py | 3 +- .../aqua/operators/state_fns/dict_state_fn.py | 3 +- .../operators/state_fns/operator_state_fn.py | 3 +- qiskit/aqua/operators/state_fns/state_fn.py | 373 ++-------------- .../aqua/operators/state_fns/state_fn_base.py | 399 ++++++++++++++++++ .../operators/state_fns/vector_state_fn.py | 3 +- 10 files changed, 463 insertions(+), 345 deletions(-) create mode 100644 qiskit/aqua/operators/state_fns/state_fn_base.py diff --git a/qiskit/aqua/operators/converters/pauli_basis_change.py b/qiskit/aqua/operators/converters/pauli_basis_change.py index 526a1f4757..de69f1724f 100644 --- a/qiskit/aqua/operators/converters/pauli_basis_change.py +++ b/qiskit/aqua/operators/converters/pauli_basis_change.py @@ -27,6 +27,7 @@ from ..list_ops.list_op import ListOp from ..list_ops.composed_op import ComposedOp from ..state_fns.state_fn import StateFn +from ..state_fns.state_fn_base import StateFnBase from ..operator_globals import H, S, I from .converter_base import ConverterBase @@ -134,7 +135,7 @@ def convert(self, operator: OperatorBase) -> OperatorBase: if isinstance(operator, (Pauli, PrimitiveOp)): cob_instr_op, dest_pauli_op = self.get_cob_circuit(operator) return self._replacement_fn(cob_instr_op, dest_pauli_op) # type: ignore - if isinstance(operator, StateFn) and 'Pauli' in operator.primitive_strings(): + if isinstance(operator, StateFnBase) and 'Pauli' in operator.primitive_strings(): # If the StateFn/Meas only contains a Pauli, use it directly. if isinstance(operator.primitive, PrimitiveOp): cob_instr_op, dest_pauli_op = self.get_cob_circuit(operator.primitive) diff --git a/qiskit/aqua/operators/gradients/circuit_gradients/lin_comb.py b/qiskit/aqua/operators/gradients/circuit_gradients/lin_comb.py index 5b37184d76..2c252636ef 100644 --- a/qiskit/aqua/operators/gradients/circuit_gradients/lin_comb.py +++ b/qiskit/aqua/operators/gradients/circuit_gradients/lin_comb.py @@ -24,6 +24,7 @@ from qiskit.aqua.operators.operator_globals import Z, I, One, Zero from qiskit.aqua.operators.primitive_ops.primitive_op import PrimitiveOp from qiskit.aqua.operators.state_fns import StateFn, CircuitStateFn, DictStateFn, VectorStateFn +from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase from qiskit.circuit import Gate, Instruction, Qubit from qiskit.circuit import (QuantumCircuit, QuantumRegister, ParameterVector, ParameterExpression) @@ -115,14 +116,14 @@ def _prepare_operator(self, if isinstance(operator, ComposedOp): # Get the measurement and the state operator - if not isinstance(operator[0], StateFn) or not operator[0].is_measurement: + if not isinstance(operator[0], StateFnBase) or not operator[0].is_measurement: raise ValueError("The given operator does not correspond to an expectation value") - if not isinstance(operator[-1], StateFn) or operator[-1].is_measurement: + if not isinstance(operator[-1], StateFnBase) or operator[-1].is_measurement: raise ValueError("The given operator does not correspond to an expectation value") if operator[0].is_measurement: if len(operator.oplist) == 2: state_op = operator[1] - if not isinstance(state_op, StateFn): + if not isinstance(state_op, StateFnBase): raise TypeError('The StateFn representing the quantum state could not be' 'extracted.') if isinstance(params, (ParameterExpression, ParameterVector)) or \ @@ -144,7 +145,7 @@ def _prepare_operator(self, else: state_op = deepcopy(operator) state_op.oplist.pop(0) - if not isinstance(state_op, StateFn): + if not isinstance(state_op, StateFnBase): raise TypeError('The StateFn representing the quantum state could not be' 'extracted.') @@ -169,7 +170,7 @@ def _prepare_operator(self, return operator.traverse(partial(self._prepare_operator, params=params)) elif isinstance(operator, ListOp): return operator.traverse(partial(self._prepare_operator, params=params)) - elif isinstance(operator, StateFn): + elif isinstance(operator, StateFnBase): if operator.is_measurement: return operator.traverse(partial(self._prepare_operator, params=params)) else: @@ -190,7 +191,7 @@ def _prepare_operator(self, return operator def _gradient_states(self, - state_op: StateFn, + state_op: StateFnBase, meas_op: Optional[OperatorBase] = None, target_params: Optional[ Union[ParameterExpression, ParameterVector, @@ -334,7 +335,7 @@ def _gradient_states(self, return op def _hessian_states(self, - state_op: StateFn, + state_op: StateFnBase, meas_op: Optional[OperatorBase] = None, target_params: Optional[Union[Tuple[ParameterExpression, ParameterExpression], diff --git a/qiskit/aqua/operators/gradients/circuit_gradients/param_shift.py b/qiskit/aqua/operators/gradients/circuit_gradients/param_shift.py index f9e09873a4..dcc12b9fbe 100644 --- a/qiskit/aqua/operators/gradients/circuit_gradients/param_shift.py +++ b/qiskit/aqua/operators/gradients/circuit_gradients/param_shift.py @@ -23,6 +23,7 @@ from qiskit.aqua.operators import (OperatorBase, StateFn, Zero, One, CircuitStateFn, CircuitOp) from qiskit.aqua.operators import SummedOp, ListOp, ComposedOp, DictStateFn, VectorStateFn +from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase from qiskit.aqua.operators.gradients.circuit_gradients.circuit_gradient \ import CircuitGradient from qiskit.aqua.operators.gradients.derivative_base import DerivativeBase @@ -210,7 +211,7 @@ def _parameter_shift(self, if isinstance(operator, ComposedOp): shifted_op = shift_constant * (pshift_op - mshift_op) - elif isinstance(operator, StateFn): + elif isinstance(operator, StateFnBase): shifted_op = ListOp( [pshift_op, mshift_op], combo_fn=partial(self._prob_combo_fn, shift_constant=shift_constant)) diff --git a/qiskit/aqua/operators/gradients/hessian.py b/qiskit/aqua/operators/gradients/hessian.py index 0fafa99210..fcf372230c 100644 --- a/qiskit/aqua/operators/gradients/hessian.py +++ b/qiskit/aqua/operators/gradients/hessian.py @@ -17,6 +17,7 @@ import numpy as np from qiskit.aqua.aqua_globals import AquaError from qiskit.aqua.operators import Zero, One, CircuitStateFn, StateFn +from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase from qiskit.aqua.operators.expectations import PauliExpectation from qiskit.aqua.operators.gradients.gradient import Gradient from qiskit.aqua.operators.gradients.hessian_base import HessianBase @@ -239,7 +240,7 @@ def is_coeff_c(coeff, c): return SummedOp([term1, term2]) - elif isinstance(operator, StateFn): + elif isinstance(operator, StateFnBase): if not operator.is_measurement: from .circuit_gradients import LinComb if isinstance(self.hess_method, LinComb): diff --git a/qiskit/aqua/operators/state_fns/circuit_state_fn.py b/qiskit/aqua/operators/state_fns/circuit_state_fn.py index fec43202ea..c097391fc7 100644 --- a/qiskit/aqua/operators/state_fns/circuit_state_fn.py +++ b/qiskit/aqua/operators/state_fns/circuit_state_fn.py @@ -24,9 +24,10 @@ from ..operator_base import OperatorBase from ..list_ops.summed_op import SummedOp from .state_fn import StateFn +from .state_fn_base import StateFnBase -class CircuitStateFn(StateFn): +class CircuitStateFn(StateFnBase): r""" A class for state functions and measurements which are defined by the action of a QuantumCircuit starting from \|0⟩, and stored using Terra's ``QuantumCircuit`` class. diff --git a/qiskit/aqua/operators/state_fns/dict_state_fn.py b/qiskit/aqua/operators/state_fns/dict_state_fn.py index f944ac1ff6..187cc284af 100644 --- a/qiskit/aqua/operators/state_fns/dict_state_fn.py +++ b/qiskit/aqua/operators/state_fns/dict_state_fn.py @@ -23,11 +23,12 @@ from ..operator_base import OperatorBase from .state_fn import StateFn +from .state_fn_base import StateFnBase from ..list_ops.list_op import ListOp from .vector_state_fn import VectorStateFn -class DictStateFn(StateFn): +class DictStateFn(StateFnBase): """ A class for state functions and measurements which are defined by a lookup table, stored in a dict. """ diff --git a/qiskit/aqua/operators/state_fns/operator_state_fn.py b/qiskit/aqua/operators/state_fns/operator_state_fn.py index b4fc507584..a2f8ec5e4c 100644 --- a/qiskit/aqua/operators/state_fns/operator_state_fn.py +++ b/qiskit/aqua/operators/state_fns/operator_state_fn.py @@ -19,6 +19,7 @@ from ..operator_base import OperatorBase from .state_fn import StateFn +from .state_fn_base import StateFnBase from .vector_state_fn import VectorStateFn from ..list_ops.list_op import ListOp from ..list_ops.summed_op import SummedOp @@ -26,7 +27,7 @@ # pylint: disable=invalid-name -class OperatorStateFn(StateFn): +class OperatorStateFn(StateFnBase): r""" A class for state functions and measurements which are defined by a density Operator, stored using an ``OperatorBase``. diff --git a/qiskit/aqua/operators/state_fns/state_fn.py b/qiskit/aqua/operators/state_fns/state_fn.py index c57718761e..b058724a5b 100644 --- a/qiskit/aqua/operators/state_fns/state_fn.py +++ b/qiskit/aqua/operators/state_fns/state_fn.py @@ -12,21 +12,22 @@ """ StateFn Class """ -from typing import Union, Optional, Callable, Set, Dict, Tuple, List -import numpy as np +from typing import Union +import numpy as np +from qiskit.circuit import Instruction, ParameterExpression from qiskit.quantum_info import Statevector from qiskit.result import Result + from qiskit import QuantumCircuit -from qiskit.circuit import Instruction, ParameterExpression from ..operator_base import OperatorBase -from ..legacy.base_operator import LegacyBaseOperator -class StateFn(OperatorBase): +class StateFn: r""" - A class for representing state functions and measurements. + A class for generating state functions such as + DictStateFn, ListStateFn, OperatorStateFn, VectorStateFn, etc. State functions are defined to be complex functions over a single binary string (as compared to an operator, which is defined as a function over two binary strings, or a @@ -45,18 +46,26 @@ class StateFn(OperatorBase): no requirement of normalization. """ - @staticmethod # pylint: disable=unused-argument - def __new__(cls, - primitive: Union[str, dict, Result, - list, np.ndarray, Statevector, - QuantumCircuit, Instruction, - OperatorBase], - coeff: Union[int, float, complex, ParameterExpression] = 1.0, - is_measurement: bool = False) -> 'StateFn': - """ A factory method to produce the correct type of StateFn subclass + def __new__( + cls, + primitive: Union[ + str, + dict, + Result, + list, + np.ndarray, + Statevector, + QuantumCircuit, + Instruction, + OperatorBase, + ], + coeff: Union[int, float, complex, ParameterExpression] = 1.0, + is_measurement: bool = False, + ) -> "StateFn": + """A factory method to produce the correct type of StateFnBase subclass based on the primitive passed in. Primitive, coeff, and is_measurement arguments - are passed into subclass's init() as-is automatically by new(). + are passed into subclass's init(). Args: primitive: The primitive which defines the behavior of the underlying State function. @@ -69,330 +78,32 @@ def __new__(cls, Raises: TypeError: Unsupported primitive type passed. """ - - # Prevents infinite recursion when subclasses are created - if cls.__name__ != StateFn.__name__: - return super().__new__(cls) - # pylint: disable=cyclic-import,import-outside-toplevel if isinstance(primitive, (str, dict, Result)): from .dict_state_fn import DictStateFn - return DictStateFn.__new__(DictStateFn, primitive) + instance = DictStateFn.__new__(DictStateFn) + instance.__init__(primitive, coeff, is_measurement) + return instance if isinstance(primitive, (list, np.ndarray, Statevector)): from .vector_state_fn import VectorStateFn - return VectorStateFn.__new__(VectorStateFn, primitive) + instance = VectorStateFn.__new__(VectorStateFn) + instance.__init__(primitive, coeff, is_measurement) + return instance if isinstance(primitive, (QuantumCircuit, Instruction)): from .circuit_state_fn import CircuitStateFn - return CircuitStateFn.__new__(CircuitStateFn, primitive) + instance = CircuitStateFn.__new__(CircuitStateFn) + instance.__init__(primitive, coeff, is_measurement) + return instance if isinstance(primitive, OperatorBase): from .operator_state_fn import OperatorStateFn - return OperatorStateFn.__new__(OperatorStateFn, primitive) - - raise TypeError('Unsupported primitive type {} passed into StateFn ' - 'factory constructor'.format(type(primitive))) - - # TODO allow normalization somehow? - def __init__(self, - primitive: Union[str, dict, Result, - list, np.ndarray, Statevector, - QuantumCircuit, Instruction, - OperatorBase] = None, - coeff: Union[int, float, complex, ParameterExpression] = 1.0, - is_measurement: bool = False) -> None: - """ - Args: - primitive: The primitive which defines the behavior of the underlying State function. - coeff: A coefficient by which the state function is multiplied. - is_measurement: Whether the StateFn is a measurement operator - """ - self._primitive = primitive - self._is_measurement = is_measurement - self._coeff = coeff - - @property - def primitive(self): - """ The primitive which defines the behavior of the underlying State function. """ - return self._primitive - - @property - def coeff(self) -> Union[int, float, complex, ParameterExpression]: - """ A coefficient by which the state function is multiplied. """ - return self._coeff - - @property - def is_measurement(self) -> bool: - """ Whether the StateFn object is a measurement Operator. """ - return self._is_measurement - - def primitive_strings(self) -> Set[str]: - raise NotImplementedError - - @property - def num_qubits(self) -> int: - raise NotImplementedError - - def add(self, other: OperatorBase) -> OperatorBase: - raise NotImplementedError - - def adjoint(self) -> OperatorBase: - raise NotImplementedError - - def _expand_dim(self, num_qubits: int) -> 'StateFn': - raise NotImplementedError - - def permute(self, permutation: List[int]) -> OperatorBase: - """Permute the qubits of the state function. - - Args: - permutation: A list defining where each qubit should be permuted. The qubit at index - j of the circuit should be permuted to position permutation[j]. - - Returns: - A new StateFn containing the permuted primitive. - """ - raise NotImplementedError - - def equals(self, other: OperatorBase) -> bool: - if not isinstance(other, type(self)) or not self.coeff == other.coeff: - return False - - return self.primitive == other.primitive - # Will return NotImplementedError if not supported - - def mul(self, scalar: Union[int, float, complex, ParameterExpression]) -> OperatorBase: - if not isinstance(scalar, (int, float, complex, ParameterExpression)): - raise ValueError('Operators can only be scalar multiplied by float or complex, not ' - '{} of type {}.'.format(scalar, type(scalar))) - - return self.__class__(self.primitive, - coeff=self.coeff * scalar, - is_measurement=self.is_measurement) - - def tensor(self, other: OperatorBase) -> OperatorBase: - r""" - Return tensor product between self and other, overloaded by ``^``. - Note: You must be conscious of Qiskit's big-endian bit printing - convention. Meaning, Plus.tensor(Zero) - produces a \|+⟩ on qubit 0 and a \|0⟩ on qubit 1, or \|+⟩⨂\|0⟩, but - would produce a QuantumCircuit like - - \|0⟩-- - \|+⟩-- - - Because Terra prints circuits and results with qubit 0 - at the end of the string or circuit. - - Args: - other: The ``OperatorBase`` to tensor product with self. - - Returns: - An ``OperatorBase`` equivalent to the tensor product of self and other. - """ - raise NotImplementedError - - def tensorpower(self, other: int) -> Union[OperatorBase, int]: - if not isinstance(other, int) or other <= 0: - raise TypeError('Tensorpower can only take positive int arguments') - temp = StateFn(self.primitive, - coeff=self.coeff, - is_measurement=self.is_measurement) # type: OperatorBase - for _ in range(other - 1): - temp = temp.tensor(self) - return temp - - def _expand_shorter_operator_and_permute(self, other: OperatorBase, - permutation: Optional[List[int]] = None) \ - -> Tuple[OperatorBase, OperatorBase]: - - from qiskit.aqua.operators import Zero - - if self == StateFn({'0': 1}, is_measurement=True): - # Zero is special - we'll expand it to the correct qubit number. - return StateFn('0' * other.num_qubits, is_measurement=True), other - elif other == Zero: - # Zero is special - we'll expand it to the correct qubit number. - return self, StateFn('0' * self.num_qubits) - - return super()._expand_shorter_operator_and_permute(other, permutation) - - def to_matrix(self, massive: bool = False) -> np.ndarray: - raise NotImplementedError - - def to_density_matrix(self, massive: bool = False) -> np.ndarray: - """ Return matrix representing product of StateFn evaluated on pairs of basis states. - Overridden by child classes. - - Args: - massive: Whether to allow large conversions, e.g. creating a matrix representing - over 16 qubits. - - Returns: - The NumPy array representing the density matrix of the State function. - - Raises: - ValueError: If massive is set to False, and exponentially large computation is needed. - """ - raise NotImplementedError - - def compose(self, other: OperatorBase, - permutation: Optional[List[int]] = None, front: bool = False) -> OperatorBase: - r""" - Composition (Linear algebra-style: A@B(x) = A(B(x))) is not well defined for states - in the binary function model, but is well defined for measurements. - - Args: - other: The Operator to compose with self. - permutation: ``List[int]`` which defines permutation on other operator. - front: If front==True, return ``other.compose(self)``. - - Returns: - An Operator equivalent to the function composition of self and other. - - Raises: - ValueError: If self is not a measurement, it cannot be composed from the right. - """ - # TODO maybe allow outers later to produce density operators or projectors, but not yet. - if not self.is_measurement and not front: - raise ValueError( - 'Composition with a Statefunction in the first operand is not defined.') - - new_self, other = self._expand_shorter_operator_and_permute(other, permutation) - - if front: - return other.compose(self) - # TODO maybe include some reduction here in the subclasses - vector and Op, op and Op, etc. - # pylint: disable=import-outside-toplevel - from qiskit.aqua.operators import CircuitOp - - if self.primitive == {'0' * self.num_qubits: 1.0} and isinstance(other, CircuitOp): - # Returning CircuitStateFn - return StateFn(other.primitive, is_measurement=self.is_measurement, - coeff=self.coeff * other.coeff) - - from qiskit.aqua.operators import ComposedOp - return ComposedOp([new_self, other]) - - def power(self, exponent: int) -> OperatorBase: - """ Compose with Self Multiple Times, undefined for StateFns. - - Args: - exponent: The number of times to compose self with self. - - Raises: - ValueError: This function is not defined for StateFns. - """ - raise ValueError('Composition power over Statefunctions or Measurements is not defined.') - - def __str__(self) -> str: - prim_str = str(self.primitive) - if self.coeff == 1.0: - return "{}({})".format('StateFunction' if not self.is_measurement - else 'Measurement', self.coeff) - else: - return "{}({}) * {}".format('StateFunction' if not self.is_measurement - else 'Measurement', - self.coeff, - prim_str) - - def __repr__(self) -> str: - return "{}({}, coeff={}, is_measurement={})".format(self.__class__.__name__, - repr(self.primitive), - self.coeff, self.is_measurement) - - def eval(self, - front: Optional[Union[str, Dict[str, complex], np.ndarray, OperatorBase]] = None - ) -> Union[OperatorBase, float, complex]: - raise NotImplementedError - - @property - def parameters(self): - params = set() - if isinstance(self.primitive, (OperatorBase, QuantumCircuit)): - params.update(self.primitive.parameters) - if isinstance(self.coeff, ParameterExpression): - params.update(self.coeff.parameters) - return params - - def assign_parameters(self, param_dict: dict) -> OperatorBase: - param_value = self.coeff - if isinstance(self.coeff, ParameterExpression): - unrolled_dict = self._unroll_param_dict(param_dict) - if isinstance(unrolled_dict, list): - # pylint: disable=import-outside-toplevel - from ..list_ops.list_op import ListOp - return ListOp([self.assign_parameters(param_dict) for param_dict in unrolled_dict]) - if self.coeff.parameters <= set(unrolled_dict.keys()): - binds = {param: unrolled_dict[param] for param in self.coeff.parameters} - param_value = float(self.coeff.bind(binds)) - return self.traverse(lambda x: x.assign_parameters(param_dict), coeff=param_value) - - # Try collapsing primitives where possible. Nothing to collapse here. - def reduce(self) -> OperatorBase: - return self - - def traverse(self, - convert_fn: Callable, - coeff: Optional[Union[int, float, complex, ParameterExpression]] = None - ) -> OperatorBase: - r""" - Apply the convert_fn to the internal primitive if the primitive is an Operator (as in - the case of ``OperatorStateFn``). Otherwise do nothing. Used by converters. - - Args: - convert_fn: The function to apply to the internal OperatorBase. - coeff: A coefficient to multiply by after applying convert_fn. - If it is None, self.coeff is used instead. - - Returns: - The converted StateFn. - """ - if coeff is None: - coeff = self.coeff - - if isinstance(self.primitive, OperatorBase): - return StateFn(convert_fn(self.primitive), - coeff=coeff, is_measurement=self.is_measurement) - else: - return self - - def to_matrix_op(self, massive: bool = False) -> OperatorBase: - """ Return a ``VectorStateFn`` for this ``StateFn``. - - Args: - massive: Whether to allow large conversions, e.g. creating a matrix representing - over 16 qubits. - - Returns: - A VectorStateFn equivalent to self. - """ - # pylint: disable=cyclic-import,import-outside-toplevel - from .vector_state_fn import VectorStateFn - return VectorStateFn(self.to_matrix(massive=massive), is_measurement=self.is_measurement) - - def to_legacy_op(self, massive: bool = False) -> LegacyBaseOperator: - raise TypeError('A StateFn cannot be represented by LegacyBaseOperator.') - - # TODO to_dict_op - - def sample(self, - shots: int = 1024, - massive: bool = False, - reverse_endianness: bool = False) -> Dict[str, Union[int, float]]: - """ Sample the state function as a normalized probability distribution. Returns dict of - bitstrings in order of probability, with values being probability. - - Args: - shots: The number of samples to take to approximate the State function. - massive: Whether to allow large conversions, e.g. creating a matrix representing - over 16 qubits. - reverse_endianness: Whether to reverse the endianness of the bitstrings in the return - dict to match Terra's big-endianness. - - Returns: - A dict containing pairs sampled strings from the State function and sampling - frequency divided by shots. - """ - raise NotImplementedError + instance = OperatorStateFn.__new__(OperatorStateFn) + instance.__init__(primitive, coeff, is_measurement) + return instance + + raise TypeError( + "Unsupported primitive type {} passed into StateFn " + "factory constructor".format(type(primitive)) + ) diff --git a/qiskit/aqua/operators/state_fns/state_fn_base.py b/qiskit/aqua/operators/state_fns/state_fn_base.py new file mode 100644 index 0000000000..de5bab28cb --- /dev/null +++ b/qiskit/aqua/operators/state_fns/state_fn_base.py @@ -0,0 +1,399 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2020. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" StateFnBase Class """ + +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import numpy as np +from qiskit.circuit import Instruction, ParameterExpression +from qiskit.quantum_info import Statevector +from qiskit.result import Result + +from qiskit import QuantumCircuit + +from ..legacy.base_operator import LegacyBaseOperator +from ..operator_base import OperatorBase +from .state_fn import StateFn + + +class StateFnBase(OperatorBase): + r""" + A base class for all StateFns: DictStateFn, ListStateFn, OperatorStateFn, VectorStateFn, etc. + + State functions are defined to be complex functions over a single binary string (as + compared to an operator, which is defined as a function over two binary strings, or a + function taking a binary function to another binary function). This function may be + called by the eval() method. + + Measurements are defined to be functionals over StateFns, taking them to real values. + Generally, this real value is interpreted to represent the probability of some classical + state (binary string) being observed from a probabilistic or quantum system represented + by a StateFn. This leads to the equivalent definition, which is that a measurement m is + a function over binary strings producing StateFns, such that the probability of measuring + a given binary string b from a system with StateFn f is equal to the inner + product between f and m(b). + + NOTE: State functions here are not restricted to wave functions, as there is + no requirement of normalization. + """ + + # TODO allow normalization somehow? + def __init__( + self, + primitive: Union[ + str, + dict, + Result, + list, + np.ndarray, + Statevector, + QuantumCircuit, + Instruction, + OperatorBase, + ] = None, + coeff: Union[int, float, complex, ParameterExpression] = 1.0, + is_measurement: bool = False, + ) -> None: + """ + Args: + primitive: The primitive which defines the behavior of the underlying State function. + coeff: A coefficient by which the state function is multiplied. + is_measurement: Whether the StateFn is a measurement operator + """ + self._primitive = primitive + self._is_measurement = is_measurement + self._coeff = coeff + + @property + def primitive(self): + """ The primitive which defines the behavior of the underlying State function. """ + return self._primitive + + @property + def coeff(self) -> Union[int, float, complex, ParameterExpression]: + """ A coefficient by which the state function is multiplied. """ + return self._coeff + + @property + def is_measurement(self) -> bool: + """ Whether the StateFn object is a measurement Operator. """ + return self._is_measurement + + def primitive_strings(self) -> Set[str]: + raise NotImplementedError + + @property + def num_qubits(self) -> int: + raise NotImplementedError + + def add(self, other: OperatorBase) -> OperatorBase: + raise NotImplementedError + + def adjoint(self) -> OperatorBase: + raise NotImplementedError + + def _expand_dim(self, num_qubits: int) -> 'StateFn': + raise NotImplementedError + + def permute(self, permutation: List[int]) -> OperatorBase: + """Permute the qubits of the state function. + + Args: + permutation: A list defining where each qubit should be permuted. The qubit at index + j of the circuit should be permuted to position permutation[j]. + + Returns: + A new StateFn containing the permuted primitive. + """ + raise NotImplementedError + + def equals(self, other: OperatorBase) -> bool: + if not isinstance(other, type(self)) or not self.coeff == other.coeff: + return False + + return self.primitive == other.primitive + # Will return NotImplementedError if not supported + + def mul( + self, scalar: Union[int, float, complex, ParameterExpression] + ) -> OperatorBase: + if not isinstance(scalar, (int, float, complex, ParameterExpression)): + raise ValueError( + "Operators can only be scalar multiplied by float or complex, not " + "{} of type {}.".format(scalar, type(scalar)) + ) + + return self.__class__( + self.primitive, + coeff=self.coeff * scalar, + is_measurement=self.is_measurement, + ) + + def tensor(self, other: OperatorBase) -> OperatorBase: + r""" + Return tensor product between self and other, overloaded by ``^``. + Note: You must be conscious of Qiskit's big-endian bit printing + convention. Meaning, Plus.tensor(Zero) + produces a \|+⟩ on qubit 0 and a \|0⟩ on qubit 1, or \|+⟩⨂\|0⟩, but + would produce a QuantumCircuit like + + \|0⟩-- + \|+⟩-- + + Because Terra prints circuits and results with qubit 0 + at the end of the string or circuit. + + Args: + other: The ``OperatorBase`` to tensor product with self. + + Returns: + An ``OperatorBase`` equivalent to the tensor product of self and other. + """ + raise NotImplementedError + + def tensorpower(self, other: int) -> Union[OperatorBase, int]: + if not isinstance(other, int) or other <= 0: + raise TypeError("Tensorpower can only take positive int arguments") + temp = StateFn( + self.primitive, coeff=self.coeff, is_measurement=self.is_measurement + ) # type: OperatorBase + for _ in range(other - 1): + temp = temp.tensor(self) + return temp + + def _expand_shorter_operator_and_permute( + self, other: OperatorBase, permutation: Optional[List[int]] = None + ) -> Tuple[OperatorBase, OperatorBase]: + + from qiskit.aqua.operators import Zero + + if self == StateFn({"0": 1}, is_measurement=True): + # Zero is special - we'll expand it to the correct qubit number. + return StateFn("0" * other.num_qubits, is_measurement=True), other + elif other == Zero: + # Zero is special - we'll expand it to the correct qubit number. + return self, StateFn("0" * self.num_qubits) + + return super()._expand_shorter_operator_and_permute(other, permutation) + + def to_matrix(self, massive: bool = False) -> np.ndarray: + raise NotImplementedError + + def to_density_matrix(self, massive: bool = False) -> np.ndarray: + """Return matrix representing product of StateFn evaluated on pairs of basis states. + Overridden by child classes. + + Args: + massive: Whether to allow large conversions, e.g. creating a matrix representing + over 16 qubits. + + Returns: + The NumPy array representing the density matrix of the State function. + + Raises: + ValueError: If massive is set to False, and exponentially large computation is needed. + """ + raise NotImplementedError + + def compose( + self, + other: OperatorBase, + permutation: Optional[List[int]] = None, + front: bool = False, + ) -> OperatorBase: + r""" + Composition (Linear algebra-style: A@B(x) = A(B(x))) is not well defined for states + in the binary function model, but is well defined for measurements. + + Args: + other: The Operator to compose with self. + permutation: ``List[int]`` which defines permutation on other operator. + front: If front==True, return ``other.compose(self)``. + + Returns: + An Operator equivalent to the function composition of self and other. + + Raises: + ValueError: If self is not a measurement, it cannot be composed from the right. + """ + # TODO maybe allow outers later to produce density operators or projectors, but not yet. + if not self.is_measurement and not front: + raise ValueError( + "Composition with a Statefunction in the first operand is not defined." + ) + + new_self, other = self._expand_shorter_operator_and_permute(other, permutation) + + if front: + return other.compose(self) + # TODO maybe include some reduction here in the subclasses - vector and Op, op and Op, etc. + # pylint: disable=import-outside-toplevel + from qiskit.aqua.operators import CircuitOp + + if self.primitive == {"0" * self.num_qubits: 1.0} and isinstance( + other, CircuitOp + ): + # Returning CircuitStateFn + return StateFn( + other.primitive, + is_measurement=self.is_measurement, + coeff=self.coeff * other.coeff, + ) + + from qiskit.aqua.operators import ComposedOp + + return ComposedOp([new_self, other]) + + def power(self, exponent: int) -> OperatorBase: + """Compose with Self Multiple Times, undefined for StateFns. + + Args: + exponent: The number of times to compose self with self. + + Raises: + ValueError: This function is not defined for StateFns. + """ + raise ValueError( + "Composition power over Statefunctions or Measurements is not defined." + ) + + def __str__(self) -> str: + prim_str = str(self.primitive) + if self.coeff == 1.0: + return "{}({})".format( + "StateFunction" if not self.is_measurement else "Measurement", + self.coeff, + ) + else: + return "{}({}) * {}".format( + "StateFunction" if not self.is_measurement else "Measurement", + self.coeff, + prim_str, + ) + + def __repr__(self) -> str: + return "{}({}, coeff={}, is_measurement={})".format( + self.__class__.__name__, + repr(self.primitive), + self.coeff, + self.is_measurement, + ) + + def eval( + self, + front: Optional[ + Union[str, Dict[str, complex], np.ndarray, OperatorBase] + ] = None, + ) -> Union[OperatorBase, float, complex]: + raise NotImplementedError + + @property + def parameters(self): + params = set() + if isinstance(self.primitive, (OperatorBase, QuantumCircuit)): + params.update(self.primitive.parameters) + if isinstance(self.coeff, ParameterExpression): + params.update(self.coeff.parameters) + return params + + def assign_parameters(self, param_dict: dict) -> OperatorBase: + param_value = self.coeff + if isinstance(self.coeff, ParameterExpression): + unrolled_dict = self._unroll_param_dict(param_dict) + if isinstance(unrolled_dict, list): + # pylint: disable=import-outside-toplevel + from ..list_ops.list_op import ListOp + + return ListOp( + [self.assign_parameters(param_dict) for param_dict in unrolled_dict] + ) + if self.coeff.parameters <= set(unrolled_dict.keys()): + binds = {param: unrolled_dict[param] for param in self.coeff.parameters} + param_value = float(self.coeff.bind(binds)) + return self.traverse( + lambda x: x.assign_parameters(param_dict), coeff=param_value + ) + + # Try collapsing primitives where possible. Nothing to collapse here. + def reduce(self) -> OperatorBase: + return self + + def traverse( + self, + convert_fn: Callable, + coeff: Optional[Union[int, float, complex, ParameterExpression]] = None, + ) -> OperatorBase: + r""" + Apply the convert_fn to the internal primitive if the primitive is an Operator (as in + the case of ``OperatorStateFn``). Otherwise do nothing. Used by converters. + + Args: + convert_fn: The function to apply to the internal OperatorBase. + coeff: A coefficient to multiply by after applying convert_fn. + If it is None, self.coeff is used instead. + + Returns: + The converted StateFn. + """ + if coeff is None: + coeff = self.coeff + + if isinstance(self.primitive, OperatorBase): + return StateFn( + convert_fn(self.primitive), + coeff=coeff, + is_measurement=self.is_measurement, + ) + else: + return self + + def to_matrix_op(self, massive: bool = False) -> OperatorBase: + """Return a ``VectorStateFn`` for this ``StateFn``. + + Args: + massive: Whether to allow large conversions, e.g. creating a matrix representing + over 16 qubits. + + Returns: + A VectorStateFn equivalent to self. + """ + # pylint: disable=cyclic-import,import-outside-toplevel + from .vector_state_fn import VectorStateFn + + return VectorStateFn( + self.to_matrix(massive=massive), is_measurement=self.is_measurement + ) + + def to_legacy_op(self, massive: bool = False) -> LegacyBaseOperator: + raise TypeError("A StateFn cannot be represented by LegacyBaseOperator.") + + # TODO to_dict_op + + def sample( + self, shots: int = 1024, massive: bool = False, reverse_endianness: bool = False + ) -> Dict[str, Union[int, float]]: + """Sample the state function as a normalized probability distribution. Returns dict of + bitstrings in order of probability, with values being probability. + + Args: + shots: The number of samples to take to approximate the State function. + massive: Whether to allow large conversions, e.g. creating a matrix representing + over 16 qubits. + reverse_endianness: Whether to reverse the endianness of the bitstrings in the return + dict to match Terra's big-endianness. + + Returns: + A dict containing pairs sampled strings from the State function and sampling + frequency divided by shots. + """ + raise NotImplementedError diff --git a/qiskit/aqua/operators/state_fns/vector_state_fn.py b/qiskit/aqua/operators/state_fns/vector_state_fn.py index 277fdccf65..6aeed79578 100644 --- a/qiskit/aqua/operators/state_fns/vector_state_fn.py +++ b/qiskit/aqua/operators/state_fns/vector_state_fn.py @@ -23,11 +23,12 @@ from ..operator_base import OperatorBase from .state_fn import StateFn +from .state_fn_base import StateFnBase from ..list_ops.list_op import ListOp from ...utils import arithmetic -class VectorStateFn(StateFn): +class VectorStateFn(StateFnBase): """ A class for state functions and measurements which are defined in vector representation, and stored using Terra's ``Statevector`` class. """