Skip to content

Commit 792dbd8

Browse files
authored
Add Task and Simulation utilities for Discrete Time Crystal Experiments (#250)
* Add Tasks for Discrete Time Crystal Experiments * Fix import and dataclass decorators * Full removal of tasks and dataclasses * Run `black` python formatter * Improve docstring attribute descriptions * fix 2021 copyright to 2022 * Add docstring newlines, add Iterator import * Updated test cases * Fix parameter descriptions * Add DTC simulation utilities for for experiments * Copyright addition, add type hints * Fix imports broken in merge * Rename dtcsimulation -> dtc_simulation * Feedback Updates * re-run formatter * Feedback updates v1 * Add more unit level tests * Added another check for probabilities_predicate * Feedback updates
1 parent 750a259 commit 792dbd8

File tree

4 files changed

+546
-1
lines changed

4 files changed

+546
-1
lines changed

recirq/time_crystals/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,9 @@
1818
EXPERIMENT_NAME,
1919
DEFAULT_BASE_DIR,
2020
)
21+
22+
from recirq.time_crystals.dtc_simulation import (
23+
run_comparison_experiment,
24+
signal_ratio,
25+
symbolic_dtc_circuit_list,
26+
)
Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
# Copyright 2022 Google
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence, Tuple, List, Iterator
16+
17+
import functools
18+
import numpy as np
19+
import sympy as sp
20+
21+
import cirq
22+
from recirq.time_crystals.dtcexperiment import DTCExperiment, comparison_experiments
23+
24+
25+
def symbolic_dtc_circuit_list(
26+
qubits: Sequence[cirq.Qid], cycles: int
27+
) -> List[cirq.Circuit]:
28+
"""Create a list of symbolically parameterized dtc circuits, with increasing cycles
29+
30+
Args:
31+
qubits: ordered sequence of available qubits, which are connected in a chain
32+
cycles: maximum number of cycles to generate up to
33+
34+
Returns:
35+
list of circuits with `0, 1, 2, ... cycles` many cycles
36+
37+
"""
38+
num_qubits = len(qubits)
39+
40+
# Symbol for g
41+
g_value = sp.Symbol("g")
42+
43+
# Symbols for random variance (h) and initial state, one per qubit
44+
local_fields = sp.symbols(f"local_field_:{num_qubits}")
45+
initial_state = sp.symbols(f"initial_state_:{num_qubits}")
46+
47+
# Symbols used for PhasedFsimGate, one for every qubit pair in the chain
48+
thetas = sp.symbols(f"theta_:{num_qubits - 1}")
49+
zetas = sp.symbols(f"zeta_:{num_qubits - 1}")
50+
chis = sp.symbols(f"chi_:{num_qubits - 1}")
51+
gammas = sp.symbols(f"gamma_:{num_qubits - 1}")
52+
phis = sp.symbols(f"phi_:{num_qubits - 1}")
53+
54+
# Initial moment of Y gates, conditioned on initial state
55+
initial_operations = cirq.Moment(
56+
[cirq.Y(qubit) ** initial_state[index] for index, qubit in enumerate(qubits)]
57+
)
58+
59+
# First component of U cycle, a moment of XZ gates.
60+
sequence_operations = []
61+
for index, qubit in enumerate(qubits):
62+
sequence_operations.append(
63+
cirq.PhasedXZGate(
64+
x_exponent=g_value,
65+
axis_phase_exponent=0.0,
66+
z_exponent=local_fields[index],
67+
)(qubit)
68+
)
69+
70+
# Initialize U cycle
71+
u_cycle = [cirq.Moment(sequence_operations)]
72+
73+
# Second and third components of U cycle, a chain of 2-qubit PhasedFSim gates
74+
# The first component is all the 2-qubit PhasedFSim gates starting on even qubits
75+
# The second component is the 2-qubit gates starting on odd qubits
76+
even_qubit_moment = []
77+
odd_qubit_moment = []
78+
for index, (qubit, next_qubit) in enumerate(zip(qubits, qubits[1:])):
79+
# Add an fsim gate
80+
coupling_gate = cirq.ops.PhasedFSimGate(
81+
theta=thetas[index],
82+
zeta=zetas[index],
83+
chi=chis[index],
84+
gamma=gammas[index],
85+
phi=phis[index],
86+
)
87+
88+
if index % 2:
89+
even_qubit_moment.append(coupling_gate.on(qubit, next_qubit))
90+
else:
91+
odd_qubit_moment.append(coupling_gate.on(qubit, next_qubit))
92+
93+
# Add the two components into the U cycle
94+
u_cycle.append(cirq.Moment(even_qubit_moment))
95+
u_cycle.append(cirq.Moment(odd_qubit_moment))
96+
97+
# Prepare a list of circuits, with n=0,1,2,3 ... cycles many cycles
98+
circuit_list = []
99+
total_circuit = cirq.Circuit(initial_operations)
100+
circuit_list.append(total_circuit.copy())
101+
for _ in range(cycles):
102+
for moment in u_cycle:
103+
total_circuit.append(moment)
104+
circuit_list.append(total_circuit.copy())
105+
106+
return circuit_list
107+
108+
109+
def simulate_dtc_circuit_list(
110+
circuit_list: Sequence[cirq.Circuit],
111+
param_resolver: cirq.ParamResolver,
112+
qubit_order: Sequence[cirq.Qid],
113+
simulator: "cirq.SimulatesIntermediateState" = None,
114+
) -> np.ndarray:
115+
"""Simulate a dtc circuit list for a particular param_resolver
116+
117+
Utilizes the fact that simulating the last circuit in the list also
118+
simulates each previous circuit along the way
119+
120+
Args:
121+
circuit_list: DTC circuit list; each element is a circuit with
122+
increasingly many cycles
123+
param_resolver: `cirq.ParamResolver` to resolve symbolic parameters
124+
qubit_order: ordered sequence of qubits connected in a chain
125+
simulator: Optional simulator object which must
126+
support the `simulate_moment_steps` function
127+
128+
Returns:
129+
`np.ndarray` of shape (len(circuit_list), 2**number of qubits) representing
130+
the probability of measuring each bit string, for each circuit in the list
131+
132+
"""
133+
# prepare simulator
134+
if simulator is None:
135+
simulator = cirq.Simulator()
136+
137+
# record lengths of circuits in list
138+
if not all(len(x) < len(y) for x, y in zip(circuit_list, circuit_list[1:])):
139+
raise ValueError("circuits in circuit_list are not in increasing order of size")
140+
circuit_positions = {len(c) - 1 for c in circuit_list}
141+
142+
# only simulate one circuit, the last one
143+
circuit = circuit_list[-1]
144+
145+
# use simulate_moment_steps to recover all of the state vectors necessary,
146+
# while only simulating the circuit list once
147+
probabilities = []
148+
for k, step in enumerate(
149+
simulator.simulate_moment_steps(
150+
circuit=circuit, param_resolver=param_resolver, qubit_order=qubit_order
151+
)
152+
):
153+
# add the state vector if the number of moments simulated so far is equal
154+
# to the length of a circuit in the circuit_list
155+
if k in circuit_positions:
156+
probabilities.append(np.abs(step.state_vector()) ** 2)
157+
158+
return np.asarray(probabilities)
159+
160+
161+
def simulate_dtc_circuit_list_sweep(
162+
circuit_list: Sequence[cirq.Circuit],
163+
param_resolvers: Sequence[cirq.ParamResolver],
164+
qubit_order: Sequence[cirq.Qid],
165+
) -> Iterator[np.ndarray]:
166+
"""Simulate a dtc circuit list over a sweep of param_resolvers
167+
168+
Args:
169+
circuit_list: DTC circuit list; each element is a circuit with
170+
increasingly many cycles
171+
param_resolvers: list of `cirq.ParamResolver`s to sweep over
172+
qubit_order: ordered sequence of qubits connected in a chain
173+
174+
Yields:
175+
for each param_resolver, `np.ndarray`s of shape
176+
(len(circuit_list), 2**number of qubits) representing the probability
177+
of measuring each bit string, for each circuit in the list
178+
179+
"""
180+
# iterate over param resolvers and simulate for each
181+
for param_resolver in param_resolvers:
182+
yield simulate_dtc_circuit_list(circuit_list, param_resolver, qubit_order)
183+
184+
185+
def get_polarizations(
186+
probabilities: np.ndarray,
187+
num_qubits: int,
188+
initial_states: np.ndarray = None,
189+
) -> np.ndarray:
190+
"""Get polarizations from matrix of probabilities, possibly autocorrelated on
191+
the initial state.
192+
193+
A polarization is the marginal probability for a qubit to measure zero or one,
194+
over all possible basis states, scaled to the range [-1. 1].
195+
196+
Args:
197+
probabilities: `np.ndarray` of shape (:, cycles, 2**qubits)
198+
representing probability to measure each bit string
199+
num_qubits: the number of qubits in the circuit the probabilities
200+
were generated from
201+
initial_states: `np.ndarray` of shape (:, qubits) representing the initial
202+
state for each dtc circuit list
203+
204+
Returns:
205+
`np.ndarray` of shape (:, cycles, qubits) that represents each
206+
qubit's polarization
207+
208+
"""
209+
# prepare list of polarizations for each qubit
210+
polarizations = []
211+
for qubit_index in range(num_qubits):
212+
# select all indices in range(2**num_qubits) for which the
213+
# associated element of the statevector has qubit_index as zero
214+
shift_by = num_qubits - qubit_index - 1
215+
state_vector_indices = [
216+
i for i in range(2**num_qubits) if not (i >> shift_by) % 2
217+
]
218+
219+
# sum over all probabilities for qubit states for which qubit_index is zero,
220+
# and rescale them to [-1,1]
221+
polarization = (
222+
2.0
223+
* np.sum(
224+
probabilities.take(indices=state_vector_indices, axis=-1),
225+
axis=-1,
226+
)
227+
- 1.0
228+
)
229+
polarizations.append(polarization)
230+
231+
# turn polarizations list into an array,
232+
# and move the new, leftmost axis for qubits to the end
233+
polarizations = np.moveaxis(np.asarray(polarizations), 0, -1)
234+
235+
# flip polarizations according to the associated initial_state, if provided
236+
# this means that the polarization of a qubit is relative to it's initial state
237+
if initial_states is not None:
238+
initial_states = 1 - 2.0 * initial_states
239+
polarizations = initial_states * polarizations
240+
241+
return polarizations
242+
243+
244+
def signal_ratio(zeta_1: np.ndarray, zeta_2: np.ndarray) -> np.ndarray:
245+
"""Calculate signal ratio between two signals
246+
247+
Signal ratio measures how different two signals are,
248+
proportional to how large they are.
249+
250+
Args:
251+
zeta_1: signal (`np.ndarray` to represent polarization over time)
252+
zeta 2: signal (`np.ndarray` to represent polarization over time)
253+
254+
Returns:
255+
computed ratio of the signals zeta_1 and zeta_2 (`np.ndarray`)
256+
to represent polarization over time)
257+
258+
"""
259+
return np.abs(zeta_1 - zeta_2) / (np.abs(zeta_1) + np.abs(zeta_2))
260+
261+
262+
def simulate_for_polarizations(
263+
dtcexperiment: DTCExperiment,
264+
circuit_list: Sequence[cirq.Circuit],
265+
autocorrelate: bool = True,
266+
take_abs: bool = False,
267+
) -> np.ndarray:
268+
"""Simulate and get polarizations for a single DTCExperiment and circuit list
269+
270+
Args:
271+
dtcexperiment: DTCExperiment noting the parameters to simulate over some
272+
number of disorder instances
273+
circuit_list: symbolic dtc circuit list
274+
autocorrelate: whether or not to autocorrelate the polarizations with their
275+
respective initial states
276+
take_abs: whether or not to take the absolute value of the polarizations
277+
278+
Returns:
279+
simulated polarizations (np.ndarray of shape (num_cycles, num_qubits)) from
280+
the experiment, averaged over disorder instances
281+
282+
"""
283+
# create param resolver sweep
284+
param_resolvers = dtcexperiment.param_resolvers()
285+
286+
# prepare simulation generator
287+
probabilities_generator = simulate_dtc_circuit_list_sweep(
288+
circuit_list, param_resolvers, dtcexperiment.qubits
289+
)
290+
291+
# map get_polarizations over probabilities_generator
292+
polarizations_generator = map(
293+
lambda probabilities, initial_state: get_polarizations(
294+
probabilities,
295+
num_qubits=len(dtcexperiment.qubits),
296+
initial_states=(initial_state if autocorrelate else None),
297+
),
298+
probabilities_generator,
299+
dtcexperiment.initial_states,
300+
)
301+
302+
# take sum of (absolute value of) polarizations over different disorder instances
303+
polarization_sum = functools.reduce(
304+
lambda x, y: x + (np.abs(y) if take_abs else y),
305+
polarizations_generator,
306+
np.zeros((len(circuit_list), len(dtcexperiment.qubits))),
307+
)
308+
309+
# get average over disorder instances
310+
disorder_averaged_polarizations = (
311+
polarization_sum / dtcexperiment.disorder_instances
312+
)
313+
314+
return disorder_averaged_polarizations
315+
316+
317+
def run_comparison_experiment(
318+
qubits: Sequence[cirq.Qid],
319+
cycles: int,
320+
disorder_instances: int,
321+
autocorrelate: bool = True,
322+
take_abs: bool = False,
323+
**kwargs,
324+
) -> Iterator[np.ndarray]:
325+
"""Run multiple DTC experiments for qubit polarizations over different parameters.
326+
327+
This uses the default parameter options noted in
328+
`dtcexperiment.comparison_experiments` for any parameter not supplied in
329+
kwargs. A DTC experiment is then created and simulated for each possible
330+
parameter combination before qubit polarizations by DTC cycle are
331+
computed and yielded. Each yield is an `np.ndarray` of shape (qubits, cycles)
332+
for a specific combination of parameters.
333+
334+
Args:
335+
qubits: ordered sequence of available qubits, which are connected in a chain
336+
cycles: maximum number of cycles to generate up to
337+
autocorrelate: whether or not to autocorrelate the polarizations with their
338+
respective initial states
339+
take_abs: whether or not to take the absolute value of the polarizations
340+
kwargs: lists of non-default argument configurations to pass through
341+
to `dtcexperiment.comparison_experiments`
342+
343+
Yields:
344+
disorder averaged polarizations, ordered by
345+
`dtcexperiment.comparison_experiments`, with all other parameters default
346+
347+
"""
348+
circuit_list = symbolic_dtc_circuit_list(qubits, cycles)
349+
for dtcexperiment in comparison_experiments(
350+
qubits=qubits, disorder_instances=disorder_instances, **kwargs
351+
):
352+
yield simulate_for_polarizations(
353+
dtcexperiment=dtcexperiment,
354+
circuit_list=circuit_list,
355+
autocorrelate=autocorrelate,
356+
take_abs=take_abs,
357+
)

0 commit comments

Comments
 (0)