Skip to content

Commit

Permalink
Add Qiskit native QPY ParameterExpression serialization
Browse files Browse the repository at this point in the history
With the release of symengine 0.13.0 we discovered a version dependence
on the payload format used for serializing symengine expressions. This
was worked around in Qiskit#13251 but this is not a sustainable solution and
only works for symengine 0.11.0 and 0.13.0 (there was no 0.12.0). While
there was always the option to use sympy to serialize the underlying
symbolic expression (there is a `use_symengine` flag on `qpy.dumps` you
can set to `False` to do this) the sympy serialzation has several
tradeoffs most importantly is much higher runtime overhead. To solve
the issue moving forward a qiskit native representation of the parameter
expression object is necessary for serialization.

This commit bumps the QPY format version to 13 and adds a new
serialization format for ParameterExpression objects. This new format
is a serialization of the API calls made to ParameterExpression that
resulted in the creation of the underlying object. To facilitate this
the ParameterExpression class is expanded to store an internal "replay"
record of the API calls used to construct the ParameterExpression
object. This internal list is what gets serialized by QPY and then on
deserialization the "replay" is replayed to reconstruct the expression
object. This is a different approach to the previous QPY representations
of the ParameterExpression objects which instead represented the internal
state stored in the ParameterExpression object with the symbolic
expression from symengine (or a sympy copy of the expression). Doing
this directly in Qiskit isn't viable though because symengine's internal
expression tree is not exposed to Python directly. There isn't any
method (private or public) to walk the expression tree to construct
a serialization format based off of it. Converting symengine to a sympy
expression and then using sympy's API to walk the expression tree is a
possibility but that would tie us to sympy which would be problematic
for Qiskit#13267 and Qiskit#13131, have significant runtime overhead, and it would
be just easier to rely on sympy's native serialization tools.

The tradeoff with this approach is that it does increase the memory
overhead of the `ParameterExpression` class because for each element
in the expression we have to store a record of it. Depending on the
depth of the expression tree this also could be a lot larger than
symengine's internal representation as we store the raw api calls made
to create the ParameterExpression but symengine is likely simplifying
it's internal representation as it builds it out. But I personally think
this tradeoff is worthwhile as it ties the serialization format to the
Qiskit objects instead of relying on a 3rd party library. This also
gives us the flexibility of changing the internal symbolic expression
library internally in the future if we decide to stop using symengine
at any point.

Fixes Qiskit#13252
  • Loading branch information
mtreinish committed Oct 22, 2024
1 parent 9a1d8d3 commit 87e2a93
Show file tree
Hide file tree
Showing 5 changed files with 483 additions and 43 deletions.
2 changes: 2 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -172,3 +173,4 @@ def __setstate__(self, state):
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
166 changes: 133 additions & 33 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Union

import numbers
Expand All @@ -30,12 +33,79 @@
ParameterValueType = Union["ParameterExpression", float]


class _OPCode(IntEnum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
POW = 4
SIN = 5
COS = 6
TAN = 7
ASIN = 8
ACOS = 9
EXP = 10
LOG = 11
SIGN = 12
DERIV = 13
CONJ = 14
SUBSTITUTE = 15
ABS = 16
ATAN = 17


_OP_CODE_MAP = (
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__pow__",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"exp",
"log",
"sign",
"gradient",
"conjugate",
"subs",
"abs",
"arctan",
)


def op_code_to_method(op_code: _OPCode):
"""Return the method name for a given op_code."""
return _OP_CODE_MAP[op_code]


@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType
rhs: ParameterValueType | None = None


@dataclass
class _SUBS:
binds: dict
op: _OPCode = _OPCode.SUBSTITUTE


class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
__slots__ = [
"_parameter_symbols",
"_parameter_keys",
"_symbol_expr",
"_name_map",
"_qpy_replay",
]

def __init__(self, symbol_map: dict, expr):
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
"""Create a new :class:`ParameterExpression`.
Not intended to be called directly, but to be instantiated via operations
Expand All @@ -54,6 +124,10 @@ def __init__(self, symbol_map: dict, expr):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
self._qpy_replay = []

@property
def parameters(self) -> set:
Expand All @@ -69,8 +143,11 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
new_op = _INSTRUCTION(_OPCode.CONJ, self)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
)
return conjugated

Expand Down Expand Up @@ -117,6 +194,7 @@ def bind(
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

new_op = _SUBS(parameter_values)
symbol_values = {}
for parameter, value in parameter_values.items():
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
Expand All @@ -143,7 +221,12 @@ def bind(
f"(Expression: {self}, Bindings: {parameter_values})."
)

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
)

def subs(
self, parameter_map: dict, allow_unknown_parameters: bool = False
Expand Down Expand Up @@ -175,6 +258,7 @@ def subs(
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
new_op = _SUBS(parameter_map)

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
Expand All @@ -192,8 +276,12 @@ def subs(
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
return ParameterExpression(
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
)

def _raise_if_passed_unknown_parameters(self, parameters):
unknown_parameters = parameters - self.parameters
Expand Down Expand Up @@ -231,7 +319,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
)

def _apply_operation(
self, operation: Callable, other: ParameterValueType, reflected: bool = False
self,
operation: Callable,
other: ParameterValueType,
reflected: bool = False,
op_code: _OPCode = None,
) -> "ParameterExpression":
"""Base method implementing math operations between Parameters and
either a constant or a second ParameterExpression.
Expand All @@ -253,7 +345,6 @@ def _apply_operation(
A new expression describing the result of the operation.
"""
self_expr = self._symbol_expr

if isinstance(other, ParameterExpression):
self._raise_if_parameter_names_conflict(other._names)
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
Expand All @@ -266,10 +357,14 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
new_op = _INSTRUCTION(op_code, other, self)
else:
expr = operation(self_expr, other_expr)
new_op = _INSTRUCTION(op_code, self, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

out_expr = ParameterExpression(parameter_symbols, expr)
out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
out_expr._name_map = self._names.copy()
if isinstance(other, ParameterExpression):
out_expr._names.update(other._names.copy())
Expand Down Expand Up @@ -313,81 +408,86 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
return float(expr_grad)

def __add__(self, other):
return self._apply_operation(operator.add, other)
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)

def __radd__(self, other):
return self._apply_operation(operator.add, other, reflected=True)
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)

def __sub__(self, other):
return self._apply_operation(operator.sub, other)
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)

def __rsub__(self, other):
return self._apply_operation(operator.sub, other, reflected=True)
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.SUB)

def __mul__(self, other):
return self._apply_operation(operator.mul, other)
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)

def __pos__(self):
return self._apply_operation(operator.mul, 1)
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)

def __neg__(self):
return self._apply_operation(operator.mul, -1)
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)

def __rmul__(self, other):
return self._apply_operation(operator.mul, other, reflected=True)
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)

def __truediv__(self, other):
if other == 0:
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
return self._apply_operation(operator.truediv, other)
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)

def __rtruediv__(self, other):
return self._apply_operation(operator.truediv, other, reflected=True)
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.DIV)

def __pow__(self, other):
return self._apply_operation(pow, other)
return self._apply_operation(pow, other, op_code=_OPCode.POW)

def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True)

def _call(self, ufunc):
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.POW)

def _call(self, ufunc, op_code):
new_op = _INSTRUCTION(op_code, self)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
)

def sin(self):
"""Sine of a ParameterExpression"""
return self._call(symengine.sin)
return self._call(symengine.sin, op_code=_OPCode.SIN)

def cos(self):
"""Cosine of a ParameterExpression"""
return self._call(symengine.cos)
return self._call(symengine.cos, op_code=_OPCode.COS)

def tan(self):
"""Tangent of a ParameterExpression"""
return self._call(symengine.tan)
return self._call(symengine.tan, op_code=_OPCode.TAN)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
return self._call(symengine.asin)
return self._call(symengine.asin, op_code=_OPCode.ASIN)

def arccos(self):
"""Arccos of a ParameterExpression"""
return self._call(symengine.acos)
return self._call(symengine.acos, op_code=_OPCode.ACOS)

def arctan(self):
"""Arctan of a ParameterExpression"""
return self._call(symengine.atan)
return self._call(symengine.atan, op_code=_OPCode.ATAN)

def exp(self):
"""Exponential of a ParameterExpression"""
return self._call(symengine.exp)
return self._call(symengine.exp, op_code=_OPCode.EXP)

def log(self):
"""Logarithm of a ParameterExpression"""
return self._call(symengine.log)
return self._call(symengine.log, op_code=_OPCode.LOG)

def sign(self):
"""Sign of a ParameterExpression"""
return self._call(symengine.sign)
return self._call(symengine.sign, op_code=_OPCode.SIGN)

def __repr__(self):
return f"{self.__class__.__name__}({str(self)})"
Expand Down Expand Up @@ -455,7 +555,7 @@ def __deepcopy__(self, memo=None):

def __abs__(self):
"""Absolute of a ParameterExpression"""
return self._call(symengine.Abs)
return self._call(symengine.Abs, _OPCode.ABS)

def abs(self):
"""Absolute of a ParameterExpression"""
Expand Down
Loading

0 comments on commit 87e2a93

Please sign in to comment.