Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
separate StateFn to Factory and BaseClass
Browse files Browse the repository at this point in the history
  • Loading branch information
ikkoham committed Oct 28, 2020
1 parent c4d3113 commit 6b4cfcd
Show file tree
Hide file tree
Showing 10 changed files with 463 additions and 345 deletions.
3 changes: 2 additions & 1 deletion qiskit/aqua/operators/converters/pauli_basis_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..list_ops.list_op import ListOp
from ..list_ops.composed_op import ComposedOp
from ..state_fns.state_fn import StateFn
from ..state_fns.state_fn_base import StateFnBase
from ..operator_globals import H, S, I
from .converter_base import ConverterBase

Expand Down Expand Up @@ -134,7 +135,7 @@ def convert(self, operator: OperatorBase) -> OperatorBase:
if isinstance(operator, (Pauli, PrimitiveOp)):
cob_instr_op, dest_pauli_op = self.get_cob_circuit(operator)
return self._replacement_fn(cob_instr_op, dest_pauli_op) # type: ignore
if isinstance(operator, StateFn) and 'Pauli' in operator.primitive_strings():
if isinstance(operator, StateFnBase) and 'Pauli' in operator.primitive_strings():
# If the StateFn/Meas only contains a Pauli, use it directly.
if isinstance(operator.primitive, PrimitiveOp):
cob_instr_op, dest_pauli_op = self.get_cob_circuit(operator.primitive)
Expand Down
15 changes: 8 additions & 7 deletions qiskit/aqua/operators/gradients/circuit_gradients/lin_comb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from qiskit.aqua.operators.operator_globals import Z, I, One, Zero
from qiskit.aqua.operators.primitive_ops.primitive_op import PrimitiveOp
from qiskit.aqua.operators.state_fns import StateFn, CircuitStateFn, DictStateFn, VectorStateFn
from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase
from qiskit.circuit import Gate, Instruction, Qubit
from qiskit.circuit import (QuantumCircuit, QuantumRegister, ParameterVector,
ParameterExpression)
Expand Down Expand Up @@ -115,14 +116,14 @@ def _prepare_operator(self,

if isinstance(operator, ComposedOp):
# Get the measurement and the state operator
if not isinstance(operator[0], StateFn) or not operator[0].is_measurement:
if not isinstance(operator[0], StateFnBase) or not operator[0].is_measurement:
raise ValueError("The given operator does not correspond to an expectation value")
if not isinstance(operator[-1], StateFn) or operator[-1].is_measurement:
if not isinstance(operator[-1], StateFnBase) or operator[-1].is_measurement:
raise ValueError("The given operator does not correspond to an expectation value")
if operator[0].is_measurement:
if len(operator.oplist) == 2:
state_op = operator[1]
if not isinstance(state_op, StateFn):
if not isinstance(state_op, StateFnBase):
raise TypeError('The StateFn representing the quantum state could not be'
'extracted.')
if isinstance(params, (ParameterExpression, ParameterVector)) or \
Expand All @@ -144,7 +145,7 @@ def _prepare_operator(self,
else:
state_op = deepcopy(operator)
state_op.oplist.pop(0)
if not isinstance(state_op, StateFn):
if not isinstance(state_op, StateFnBase):
raise TypeError('The StateFn representing the quantum state could not be'
'extracted.')

Expand All @@ -169,7 +170,7 @@ def _prepare_operator(self,
return operator.traverse(partial(self._prepare_operator, params=params))
elif isinstance(operator, ListOp):
return operator.traverse(partial(self._prepare_operator, params=params))
elif isinstance(operator, StateFn):
elif isinstance(operator, StateFnBase):
if operator.is_measurement:
return operator.traverse(partial(self._prepare_operator, params=params))
else:
Expand All @@ -190,7 +191,7 @@ def _prepare_operator(self,
return operator

def _gradient_states(self,
state_op: StateFn,
state_op: StateFnBase,
meas_op: Optional[OperatorBase] = None,
target_params: Optional[
Union[ParameterExpression, ParameterVector,
Expand Down Expand Up @@ -334,7 +335,7 @@ def _gradient_states(self,
return op

def _hessian_states(self,
state_op: StateFn,
state_op: StateFnBase,
meas_op: Optional[OperatorBase] = None,
target_params: Optional[Union[Tuple[ParameterExpression,
ParameterExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from qiskit.aqua.operators import (OperatorBase, StateFn, Zero, One, CircuitStateFn,
CircuitOp)
from qiskit.aqua.operators import SummedOp, ListOp, ComposedOp, DictStateFn, VectorStateFn
from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase
from qiskit.aqua.operators.gradients.circuit_gradients.circuit_gradient \
import CircuitGradient
from qiskit.aqua.operators.gradients.derivative_base import DerivativeBase
Expand Down Expand Up @@ -210,7 +211,7 @@ def _parameter_shift(self,
if isinstance(operator, ComposedOp):
shifted_op = shift_constant * (pshift_op - mshift_op)

elif isinstance(operator, StateFn):
elif isinstance(operator, StateFnBase):
shifted_op = ListOp(
[pshift_op, mshift_op],
combo_fn=partial(self._prob_combo_fn, shift_constant=shift_constant))
Expand Down
3 changes: 2 additions & 1 deletion qiskit/aqua/operators/gradients/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from qiskit.aqua.aqua_globals import AquaError
from qiskit.aqua.operators import Zero, One, CircuitStateFn, StateFn
from qiskit.aqua.operators.state_fns.state_fn_base import StateFnBase
from qiskit.aqua.operators.expectations import PauliExpectation
from qiskit.aqua.operators.gradients.gradient import Gradient
from qiskit.aqua.operators.gradients.hessian_base import HessianBase
Expand Down Expand Up @@ -239,7 +240,7 @@ def is_coeff_c(coeff, c):

return SummedOp([term1, term2])

elif isinstance(operator, StateFn):
elif isinstance(operator, StateFnBase):
if not operator.is_measurement:
from .circuit_gradients import LinComb
if isinstance(self.hess_method, LinComb):
Expand Down
3 changes: 2 additions & 1 deletion qiskit/aqua/operators/state_fns/circuit_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from ..operator_base import OperatorBase
from ..list_ops.summed_op import SummedOp
from .state_fn import StateFn
from .state_fn_base import StateFnBase


class CircuitStateFn(StateFn):
class CircuitStateFn(StateFnBase):
r"""
A class for state functions and measurements which are defined by the action of a
QuantumCircuit starting from \|0⟩, and stored using Terra's ``QuantumCircuit`` class.
Expand Down
3 changes: 2 additions & 1 deletion qiskit/aqua/operators/state_fns/dict_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@

from ..operator_base import OperatorBase
from .state_fn import StateFn
from .state_fn_base import StateFnBase
from ..list_ops.list_op import ListOp
from .vector_state_fn import VectorStateFn


class DictStateFn(StateFn):
class DictStateFn(StateFnBase):
""" A class for state functions and measurements which are defined by a lookup table,
stored in a dict.
"""
Expand Down
3 changes: 2 additions & 1 deletion qiskit/aqua/operators/state_fns/operator_state_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from ..operator_base import OperatorBase
from .state_fn import StateFn
from .state_fn_base import StateFnBase
from .vector_state_fn import VectorStateFn
from ..list_ops.list_op import ListOp
from ..list_ops.summed_op import SummedOp


# pylint: disable=invalid-name

class OperatorStateFn(StateFn):
class OperatorStateFn(StateFnBase):
r"""
A class for state functions and measurements which are defined by a density Operator,
stored using an ``OperatorBase``.
Expand Down
Loading

0 comments on commit 6b4cfcd

Please sign in to comment.