Skip to content

Commit

Permalink
Name-ordered circuit parameters (#5759)
Browse files Browse the repository at this point in the history
* change unique list to MappingView

* fix test

* start adding tests

* handle anonymous parameter assigment

* fix bind_parameters

* fix compose

* add reno, remove todos

* fix docstring

* fix inconsistent argname

* fix renaming gone wrong

* changes from code review

* use deprecate_arguments
* add deprecation section in reno
* add more tests

* fix naming in reno

* make compose(front) more efficient

* update paramtables in compose w/o iterating over all gates

* properly rebuild parameterable upon compose

* use params instead of param_dict

* fix __new__

* fix sorting

* fix tests

* fix parameter order in test

* Apply suggestions from code review

Co-authored-by: Luciano Bello <bel@zurich.ibm.com>

* rename params->parameters, update reno

* add tests, remove access table by index

* use ListEqual over Equal

* rm unused islice import

* Rm star-swallowers from bind/assign_parameters

* store parent instance instead of just name

* fix name `params` -> `parameters`

* fix typo

* fix typo #2

* add test for pickling paramvecelements

Co-authored-by: Erick Winston <ewinston@us.ibm.com>
Co-authored-by: Luciano Bello <bel@zurich.ibm.com>
  • Loading branch information
3 people authored Mar 2, 2021
1 parent cb21109 commit 296a745
Show file tree
Hide file tree
Showing 8 changed files with 569 additions and 58 deletions.
30 changes: 17 additions & 13 deletions qiskit/circuit/library/n_local/n_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from qiskit.circuit.quantumregister import QuantumRegister
from qiskit.circuit import Instruction, Parameter, ParameterVector, ParameterExpression
from qiskit.circuit.parametertable import ParameterTable
from qiskit.utils.deprecation import deprecate_arguments

from ..blueprintcircuit import BlueprintCircuit

Expand Down Expand Up @@ -736,9 +737,12 @@ def add_layer(self,

return self

def assign_parameters(self, param_dict: Union[dict, List[float], List[Parameter],
@deprecate_arguments({'param_dict': 'parameters'})
def assign_parameters(self, parameters: Union[dict, List[float], List[Parameter],
ParameterVector],
inplace: bool = False) -> Optional[QuantumCircuit]:
inplace: bool = False,
param_dict: Optional[dict] = None # pylint: disable=unused-argument
) -> Optional[QuantumCircuit]:
"""Assign parameters to the n-local circuit.
This method also supports passing a list instead of a dictionary. If a list
Expand All @@ -756,31 +760,31 @@ def assign_parameters(self, param_dict: Union[dict, List[float], List[Parameter]
if self._data is None:
self._build()

if not isinstance(param_dict, dict):
if len(param_dict) != self.num_parameters:
if not isinstance(parameters, dict):
if len(parameters) != self.num_parameters:
raise AttributeError('If the parameters are provided as list, the size must match '
'the number of parameters ({}), but {} are given.'.format(
self.num_parameters, len(param_dict)
self.num_parameters, len(parameters)
))
unbound_params = [param for param in self._ordered_parameters if
isinstance(param, ParameterExpression)]
unbound_parameters = [param for param in self._ordered_parameters if
isinstance(param, ParameterExpression)]

# to get a sorted list of unique parameters, keep track of the already used parameters
# in a set and add the parameters to the unique list only if not existing in the set
used = set()
unbound_unique_params = []
for param in unbound_params:
unbound_unique_parameters = []
for param in unbound_parameters:
if param not in used:
unbound_unique_params.append(param)
unbound_unique_parameters.append(param)
used.add(param)

param_dict = dict(zip(unbound_unique_params, param_dict))
parameters = dict(zip(unbound_unique_parameters, parameters))

if inplace:
new = [param_dict.get(param, param) for param in self.ordered_parameters]
new = [parameters.get(param, param) for param in self.ordered_parameters]
self._ordered_parameters = new

return super().assign_parameters(param_dict, inplace=inplace)
return super().assign_parameters(parameters, inplace=inplace)

def _parameterize_block(self, block, param_iter=None, rep_num=None, block_num=None,
indices=None, params=None):
Expand Down
209 changes: 208 additions & 1 deletion qiskit/circuit/parametertable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
"""
Look-up table for variable parameters in QuantumCircuit.
"""
from collections.abc import MutableMapping
import warnings
import functools
from collections.abc import MutableMapping, MappingView

from .instruction import Instruction

Expand Down Expand Up @@ -79,3 +81,208 @@ def __len__(self):

def __repr__(self):
return 'ParameterTable({})'.format(repr(self._table))


def _deprecated_set_method():
def deprecate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# warn only once
if not wrapper._warned:
warnings.warn(f'The ParameterView.{func.__name__} method is deprecated as of '
'Qiskit Terra 0.17.0 and will be removed no sooner than 3 months '
'after the release date. Circuit parameters are returned as View '
'object, not set. To use set methods you can explicitly cast to a '
'set.', DeprecationWarning, stacklevel=2)
wrapper._warned = True
return func(*args, **kwargs)
wrapper._warned = False
return wrapper
return deprecate


class ParameterView(MappingView):
"""Temporary class to transition from a set return-type to list.
Derives from a list but implements all set methods, but all set-methods emit deprecation
warnings.
"""

def __init__(self, iterable=None):
if iterable is not None:
self.data = list(iterable)
else:
self.data = []

super().__init__(self.data)

@_deprecated_set_method()
def add(self, x):
"""Add a new element."""
if x not in self.data:
self.data.append(x)

def copy(self):
"""Copy the ParameterView."""
return self.__class__(self.data.copy())

@_deprecated_set_method()
def difference(self, *s):
"""Get the difference between self and the input."""
return self.__sub__(s)

@_deprecated_set_method()
def difference_update(self, *s):
"""Get the difference between self and the input in-place."""
for element in self:
if element in s:
self.remove(element)

@_deprecated_set_method()
def discard(self, x):
"""Remove an element from self."""
if x in self:
self.remove(x)

@_deprecated_set_method()
def intersection(self, *x):
"""Get the intersection between self and the input."""
return self.__and__(x)

@_deprecated_set_method()
def intersection_update(self, *x):
"""Get the intersection between self and the input in-place."""
return self.__iand__(x)

def isdisjoint(self, x):
"""Check whether self and the input are disjoint."""
return not any(element in self for element in x)

@_deprecated_set_method()
def issubset(self, x):
"""Check whether self is a subset of the input."""
return self.__le__(x)

@_deprecated_set_method()
def issuperset(self, x):
"""Check whether self is a superset of the input."""
return self.__ge__(x)

@_deprecated_set_method()
def symmetric_difference(self, x):
"""Get the symmetric difference of self and the input."""
return self.__xor__(x)

@_deprecated_set_method()
def symmetric_difference_update(self, x):
"""Get the symmetric difference of self and the input in-place."""
backward = x.difference(self)
self.difference_update(x)
self.update(backward)

@_deprecated_set_method()
def union(self, *x):
"""Get the union of self and the input."""
return self.__or__(x)

@_deprecated_set_method()
def update(self, *x):
"""Update self with the input."""
for element in x:
self.add(element)

def remove(self, x):
"""Remove an existing element from the view."""
self.data.remove(x)

def __repr__(self):
"""Format the class as string."""
return f'ParameterView({self.data})'

def __getitem__(self, index):
"""Get items."""
return self.data[index]

def __and__(self, x):
"""Get the intersection between self and the input."""
inter = []
for element in self:
if element in x:
inter.append(element)

return self.__class__(inter)

def __rand__(self, x):
"""Get the intersection between self and the input."""
return self.__and__(x)

def __iand__(self, x):
"""Get the intersection between self and the input in-place."""
for element in self:
if element not in x:
self.remove(element)
return self

def __len__(self):
"""Get the length."""
return len(self.data)

def __or__(self, x):
"""Get the union of self and the input."""
return set(self) | set(x)

def __ior__(self, x):
"""Update self with the input."""
self.update(*x)
return self

def __sub__(self, x):
"""Get the difference between self and the input."""
return set(self) - set(x)

@_deprecated_set_method()
def __isub__(self, x):
"""Get the difference between self and the input in-place."""
return self.difference_update(*x)

def __xor__(self, x):
"""Get the symmetric difference between self and the input."""
return set(self) ^ set(x)

@_deprecated_set_method()
def __ixor__(self, x):
"""Get the symmetric difference between self and the input in-place."""
self.symmetric_difference_update(x)
return self

def __ne__(self, other):
return set(other) != set(self)

def __eq__(self, other):
return set(other) == set(self)

def __le__(self, x):
return all(element in x for element in self)

def __lt__(self, x):
if x != self:
return self <= x
return False

def __ge__(self, x):
return all(element in self for element in x)

def __gt__(self, x):
if x != self:
return self >= x
return False

def __iter__(self):
return iter(self.data)

def __contains__(self, x):
return x in self.data

__hash__: None # type: ignore
__rand__ = __and__
__ror__ = __or__
40 changes: 38 additions & 2 deletions qiskit/circuit/parametervector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,45 @@

"""Parameter Vector Class to simplify management of parameter lists."""

from uuid import uuid4

from .parameter import Parameter


class ParameterVectorElement(Parameter):
"""An element of a ParameterVector."""

def __new__(cls, vector, index, uuid=None): # pylint:disable=unused-argument
obj = object.__new__(cls)

if uuid is None:
obj._uuid = uuid4()
else:
obj._uuid = uuid

obj._hash = hash(obj._uuid)
return obj

def __getnewargs__(self):
return (self.vector, self.index, self._uuid)

def __init__(self, vector, index):
name = f'{vector.name}[{index}]'
super().__init__(name)
self._vector = vector
self._index = index

@property
def index(self):
"""Get the index of this element in the parent vector."""
return self._index

@property
def vector(self):
"""Get the parent vector instance."""
return self._vector


class ParameterVector:
"""ParameterVector class to quickly generate lists of parameters."""

Expand All @@ -23,7 +59,7 @@ def __init__(self, name, length=0):
self._params = []
self._size = length
for i in range(length):
self._params += [Parameter('{}[{}]'.format(self._name, i))]
self._params += [ParameterVectorElement(self, i)]

@property
def name(self):
Expand Down Expand Up @@ -69,5 +105,5 @@ def resize(self, length):
"""
if length > len(self._params):
for i in range(len(self._params), length):
self._params += [Parameter('{}[{}]'.format(self._name, i))]
self._params += [ParameterVectorElement(self, i)]
self._size = length
Loading

0 comments on commit 296a745

Please sign in to comment.