Skip to content

Commit

Permalink
modified embedding to allow time-dependent HamEvo; fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-a committed Sep 23, 2024
1 parent 3048b8a commit 671b0a0
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 37 deletions.
46 changes: 29 additions & 17 deletions qadence/backends/pyqtorch/convert_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from functools import partial, reduce
from itertools import chain as flatten
from typing import Callable, Sequence
Expand Down Expand Up @@ -50,6 +51,12 @@
sympy.log: "log",
sympy.tan: "tan",
sympy.tanh: "tanh",
sympy.Heaviside: "heaviside",
sympy.Abs: "abs",
sympy.exp: "exp",
sympy.acos: "acos",
sympy.asin: "asin",
sympy.atan: "atan",
}


Expand Down Expand Up @@ -91,7 +98,7 @@ def extract_parameter(block: ScaleBlock | ParametricBlock, config: Configuration
return config.get_param_name(block)[0]


def sympy_to_pyq(expr: sympy.Expr) -> ConcretizedCallable:
def sympy_to_pyq(expr: sympy.Expr) -> ConcretizedCallable | Tensor:
"""Convert sympy expression to pyqtorch ConcretizedCallable object.
Args:
Expand All @@ -103,18 +110,14 @@ def sympy_to_pyq(expr: sympy.Expr) -> ConcretizedCallable:

# base case - independent argument
if len(expr.args) == 0:
res = (
float(expr)
if str(expr).replace(".", "", 1).replace("-", "", 1).isdigit()
else str(expr)
)
if isinstance(res, str):
try:
res = torch.as_tensor(float(expr))
except Exception as e:
res = str(expr)

if "/" in res:
# found a rational
res = float(sympy.Rational(res).evalf())
if "fix" in res: # type: ignore [operator]
# found a fixed parameter - convert to float
res = torch.as_tensor(float(res.split("_")[1])) # type: ignore [attr-defined]
res = torch.as_tensor(float(sympy.Rational(res).evalf()))
return res

# iterate through current function arguments
Expand Down Expand Up @@ -154,16 +157,25 @@ def convert_block(

if isinstance(block, ScaleBlock):
scaled_ops = convert_block(block.block, n_qubits, config)
scale = extract_parameter(block, config)
return [pyq.Scale(pyq.Sequence(scaled_ops), sympy_to_pyq(sympy.parse_expr(scale)))]
scale = extract_parameter(block, config=config)

# replace underscore by dot when underscore is between two numbers in string
if isinstance(scale, str):
scale = re.sub(r"(?<=\d)_(?=\d)", ".", scale)
if isinstance(scale, str) and not config._use_gate_params:
param = sympy_to_pyq(sympy.parse_expr(scale))
else:
param = scale

return [pyq.Scale(pyq.Sequence(scaled_ops), param)]

elif isinstance(block, TimeEvolutionBlock):
if getattr(block.generator, "is_time_dependent", False):
generator = convert_block(
block.generator, config=Configuration(_use_gate_params=False) # type: ignore [arg-type]
)[0]
config._use_gate_params = False
generator = convert_block(block.generator, config=config)[0] # type: ignore [arg-type]
elif isinstance(block.generator, sympy.Basic):
generator = config.get_param_name(block)[1]

elif isinstance(block.generator, Tensor):
m = block.generator.to(dtype=cdouble)
generator = convert_block(
Expand All @@ -177,6 +189,7 @@ def convert_block(
else:
generator = convert_block(block.generator, n_qubits, config)[0] # type: ignore[arg-type]
time_param = config.get_param_name(block)[0]

return [
pyq.HamiltonianEvolution(
qubit_support=qubit_support,
Expand Down Expand Up @@ -211,7 +224,6 @@ def convert_block(
if isinstance(block, U):
op = pyq_cls(qubit_support[0], *config.get_param_name(block))
else:
# param = sympy_to_pyq(sympy.parse_expr(extract_parameter(block, config)))
param = extract_parameter(block, config)
op = pyq_cls(qubit_support[0], param)
else:
Expand Down
16 changes: 13 additions & 3 deletions qadence/blocks/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sympy
from numpy import array as nparray
from numpy import cdouble as npcdouble
from torch import tensor
from torch import as_tensor, tensor

from qadence.blocks import (
AbstractBlock,
Expand Down Expand Up @@ -142,8 +142,18 @@ def embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType:
gate_lvl_params[uuid] = embedded_params[e]
return gate_lvl_params
else:
out = {stringify(k): v for k, v in embedded_params.items()}
out.update({"orig_param_values": inputs})
embedded_params.update(inputs)
for k, v in params.items():
if k not in embedded_params:
embedded_params[k] = v
out = {
stringify(k)
if not isinstance(k, str)
else k: as_tensor(v)[None]
if as_tensor(v).ndim == 0
else v
for k, v in embedded_params.items()
}
return out

params: ParamDictType
Expand Down
6 changes: 1 addition & 5 deletions qadence/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ def _(
diff_mode = DiffMode.AD
bknd = backend_factory(backend, diff_mode=diff_mode, configuration=configuration)
conv = bknd.convert(circuit)
if backend == BackendName.PYQTORCH:
vals = values
else:
vals = conv.embedding_fn(conv.params, values)
with no_grad():
return bknd.run(
circuit=conv.circuit,
param_values=vals,
param_values=conv.embedding_fn(conv.params, values),
state=state,
endianness=endianness,
)
Expand Down
5 changes: 1 addition & 4 deletions qadence/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ def run(
if values is None:
values = {}

if self._backend_name == BackendName.PYQTORCH:
params = values
else:
params = self.embedding_fn(self._params, values)
params = self.embedding_fn(self._params, values)

return self.backend.run(self._circuit, params, state=state, endianness=endianness)

Expand Down
1 change: 1 addition & 0 deletions qadence/operations/ham_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
)
ps = {"parameter": Parameter(parameter), **gen_exprs}
self.parameters = ParamMap(**ps)
self.time_param = parameter
self.generator = generator
self.duration = duration

Expand Down
9 changes: 5 additions & 4 deletions tests/analog/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,17 @@ def test_pulser_pyq_addressing(amp: float, det: float, spacing: float) -> None:
assert torch.allclose(expval_pulser, expval_pyq, atol=MIDDLE_ACCEPTANCE)


@pytest.mark.skip(reason="Heaviside function not differentiable in torch; decide how to fix this")
@pytest.mark.flaky(max_runs=5)
def test_addressing_training() -> None:
n_qubits = 3
f_value = torch.rand(1)

# define training parameters
w_amp = {i: f"w_amp{i}" for i in range(n_qubits)}
w_det = {i: f"w_det{i}" for i in range(n_qubits)}
amp = "amp"
det = "det"
w_amp = {i: f"w_ampl{i}" for i in range(n_qubits)}
w_det = {i: f"w_detun{i}" for i in range(n_qubits)}
amp = "ampl"
det = "detun"

# define pattern and device specs
pattern = AddressingPattern(
Expand Down
4 changes: 2 additions & 2 deletions tests/engines/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_embeddings() -> None:
"orig_param_values", {}
) # TODO: remove this when embedding system is updated

assert len(list(low_level_params.keys())) == 9
assert len(list(low_level_params.keys())) == 9 + len(inputs)

assert [v for k, v in low_level_params.items() if k.startswith("fix_")][0] == 0.5
assert torch.allclose(low_level_params["3*x"], 3 * inputs["x"])
Expand Down Expand Up @@ -186,7 +186,7 @@ def func(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
assert torch.autograd.gradgradcheck(func, (inputs_x, inputs_y, param_w))

assert torch.allclose(
finite_diff(lambda x: func(x, inputs_y, param_w), inputs_x.reshape(-1, 1), (0,)),
finite_diff(lambda x: func(x.squeeze(0), inputs_y, param_w), inputs_x.reshape(-1, 1), (0,)),
torch.autograd.grad(expval, inputs_x, torch.ones_like(expval), create_graph=True)[0],
)

Expand Down
6 changes: 4 additions & 2 deletions tests/qadence/test_quantum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def test_quantum_model_parameters(parametric_circuit: QuantumCircuit) -> None:
embedded_params_ad.pop(
"orig_param_values", {}
) # TODO: remove this when embedding system is updated
assert len(embedded_params_ad) == 5
assert (
len(embedded_params_ad) == 5 + 1
) # adding one because original param x is included for PYQ + AD
assert len(embedded_params_psr) == 6


Expand All @@ -71,7 +73,7 @@ def test_quantum_model_duplicate_expr(duplicate_expression_circuit: QuantumCircu
embedded_params_ad.pop(
"orig_param_values", {}
) # TODO: remove this when embedding system is updated
assert len(embedded_params_ad) == 2
assert len(embedded_params_ad) == 2 + 4 # adding 4 because all original params are included
assert len(embedded_params_psr) == 8


Expand Down

0 comments on commit 671b0a0

Please sign in to comment.