Skip to content

Commit

Permalink
Add readout plotting tools (#6425)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliottrosenberg authored Jan 26, 2024
1 parent ee56c59 commit 2ef1909
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
122 changes: 122 additions & 0 deletions cirq-core/cirq/experiments/single_qubit_readout_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

import sympy
import numpy as np
import matplotlib.pyplot as plt
import cirq.vis.heatmap as cirq_heatmap
import cirq.vis.histogram as cirq_histogram
from cirq.devices import grid_qubit
from cirq import circuits, ops, study

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,6 +55,124 @@ def _json_dict_(self) -> Dict[str, Any]:
'timestamp': self.timestamp,
}

def plot_heatmap(
self,
axs: Optional[tuple[plt.Axes, plt.Axes]] = None,
annotation_format: str = '0.1%',
**plot_kwargs: Any,
) -> tuple[plt.Axes, plt.Axes]:
"""Plot a heatmap of the readout errors. If qubits are not cirq.GridQubits, throws an error.
Args:
axs: a tuple of the plt.Axes to plot on. If not given, a new figure is created,
plotted on, and shown.
annotation_format: The format string for the numbers in the heatmap.
**plot_kwargs: Arguments to be passed to 'cirq.Heatmap.plot()'.
Returns:
The two plt.Axes containing the plot.
Raises:
ValueError: axs does not contain two plt.Axes
TypeError: qubits are not cirq.GridQubits
"""

if axs is None:
_, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4))

else:
if (
not isinstance(axs, (tuple, list, np.ndarray))
or len(axs) != 2
or type(axs[0]) != plt.Axes
or type(axs[1]) != plt.Axes
): # pragma: no cover
raise ValueError('axs should be a length-2 tuple of plt.Axes') # pragma: no cover
for ax, title, data in zip(
axs,
['$|0\\rangle$ errors', '$|1\\rangle$ errors'],
[self.zero_state_errors, self.one_state_errors],
):
data_with_grid_qubit_keys = {}
for qubit in data:
if type(qubit) != grid_qubit.GridQubit:
raise TypeError(f'{qubit} must be of type cirq.GridQubit') # pragma: no cover
data_with_grid_qubit_keys[qubit] = data[qubit] # just for typecheck
_, _ = cirq_heatmap.Heatmap(data_with_grid_qubit_keys).plot(
ax, annotation_format=annotation_format, title=title, **plot_kwargs
)
return axs[0], axs[1]

def plot_integrated_histogram(
self,
ax: Optional[plt.Axes] = None,
cdf_on_x: bool = False,
axis_label: str = 'Readout error rate',
semilog: bool = True,
median_line: bool = True,
median_label: Optional[str] = 'median',
mean_line: bool = False,
mean_label: Optional[str] = 'mean',
show_zero: bool = False,
title: Optional[str] = None,
**kwargs,
) -> plt.Axes:
"""Plot the readout errors using cirq.integrated_histogram().
Args:
ax: The axis to plot on. If None, we generate one.
cdf_on_x: If True, flip the axes compared the above example.
axis_label: Label for x axis (y-axis if cdf_on_x is True).
semilog: If True, force the x-axis to be logarithmic.
median_line: If True, draw a vertical line on the median value.
median_label: If drawing median line, optional label for it.
mean_line: If True, draw a vertical line on the mean value.
mean_label: If drawing mean line, optional label for it.
title: Title of the plot. If None, we assign "N={len(data)}".
show_zero: If True, moves the step plot up by one unit by prepending 0
to the data.
**kwargs: Kwargs to forward to `ax.step()`. Some examples are
color: Color of the line.
linestyle: Linestyle to use for the plot.
lw: linewidth for integrated histogram.
ms: marker size for a histogram trace.
Returns:
The axis that was plotted on.
"""

ax = cirq_histogram.integrated_histogram(
data=self.zero_state_errors,
ax=ax,
cdf_on_x=cdf_on_x,
semilog=semilog,
median_line=median_line,
median_label=median_label,
mean_line=mean_line,
mean_label=mean_label,
show_zero=show_zero,
color='C0',
label='$|0\\rangle$ errors',
**kwargs,
)
ax = cirq_histogram.integrated_histogram(
data=self.one_state_errors,
ax=ax,
cdf_on_x=cdf_on_x,
axis_label=axis_label,
semilog=semilog,
median_line=median_line,
median_label=median_label,
mean_line=mean_line,
mean_label=mean_label,
show_zero=show_zero,
title=title,
color='C1',
label='$|1\\rangle$ errors',
**kwargs,
)
ax.legend(loc='best')
ax.set_ylabel('Percentile')
return ax

@classmethod
def _from_json_dict_(
cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_estimate_single_qubit_readout_errors_with_noise():


def test_estimate_parallel_readout_errors_no_noise():
qubits = cirq.LineQubit.range(10)
qubits = [cirq.GridQubit(i, 0) for i in range(10)]
sampler = cirq.Simulator()
repetitions = 1000
result = cirq.estimate_parallel_single_qubit_readout_errors(
Expand All @@ -97,6 +97,8 @@ def test_estimate_parallel_readout_errors_no_noise():
assert result.one_state_errors == {q: 0 for q in qubits}
assert result.repetitions == repetitions
assert isinstance(result.timestamp, float)
_ = result.plot_integrated_histogram()
_, _ = result.plot_heatmap()


def test_estimate_parallel_readout_errors_all_zeros():
Expand Down

0 comments on commit 2ef1909

Please sign in to comment.