Skip to content

Commit

Permalink
Plot State City Visualization Fixes (#10590) (#11116)
Browse files Browse the repository at this point in the history
* fix zordering and labels outside of plot

* lint and reno

* Changed reference image for state city

* correct visual of negative real value bars

* added release notes for negative real bars fix

* remove debug print statement in plot_state_city

* fix with tox eblack

* append rho to title

Co-authored-by: Luciano Bello <bel@zurich.ibm.com>

---------

Co-authored-by: Luciano Bello <bel@zurich.ibm.com>
(cherry picked from commit 4b5546f)

Co-authored-by: AlexanderGroeger <46076580+AlexanderGroeger@users.noreply.github.com>
  • Loading branch information
mergify[bot] and AlexanderGroeger authored Oct 30, 2023
1 parent b563476 commit 2e861ba
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 103 deletions.
199 changes: 96 additions & 103 deletions qiskit/visualization/state_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def plot_state_city(
plot_state_city(state, alpha=0.6)
"""
import matplotlib.colors as mcolors
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

Expand All @@ -463,8 +464,7 @@ def plot_state_city(
column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

lx = len(datareal[0]) # Work out matrix dimensions
ly = len(datareal[:, 0])
ly, lx = datareal.shape[:2]
xpos = np.arange(0, lx, 1) # Set up a mesh of positions
ypos = np.arange(0, ly, 1)
xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)
Expand All @@ -479,22 +479,21 @@ def plot_state_city(
dzi = dataimag.flatten()

if color is None:
color = ["#648fff", "#648fff"]
real_color, imag_color = "#648fff", "#648fff"
else:
if len(color) != 2:
raise ValueError("'color' must be a list of len=2.")
if color[0] is None:
color[0] = "#648fff"
if color[1] is None:
color[1] = "#648fff"
real_color = "#648fff" if color[0] is None else color[0]
imag_color = "#648fff" if color[1] is None else color[1]
if ax_real is None and ax_imag is None:
# set default figure size
if figsize is None:
figsize = (15, 5)
figsize = (16, 8)

fig = plt.figure(figsize=figsize, facecolor="w")
ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)

fig = plt.figure(figsize=figsize)
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax2 = fig.add_subplot(1, 2, 2, projection="3d")
elif ax_real is not None:
fig = ax_real.get_figure()
ax1 = ax_real
Expand All @@ -504,109 +503,103 @@ def plot_state_city(
ax1 = None
ax2 = ax_imag

max_dzr = max(dzr)
min_dzr = min(dzr)
min_dzi = np.min(dzi)
fig.tight_layout()

max_dzr = np.max(dzr)
max_dzi = np.max(dzi)

# There seems to be a rounding error in which some zero bars are negative
dzr = np.clip(dzr, 0, None)
# Figure scaling variables since fig.tight_layout won't work
fig_width, fig_height = fig.get_size_inches()
max_plot_size = min(fig_width / 2.25, fig_height)
max_font_size = int(3 * max_plot_size)
max_zoom = 10 / (10 + np.sqrt(max_plot_size))

if ax1 is not None:
fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
for idx, cur_zpos in enumerate(zpos):
if dzr[idx] > 0:
zorder = 2
else:
zorder = 0
b1 = ax1.bar3d(
xpos[idx],
ypos[idx],
cur_zpos,
dx[idx],
dy[idx],
dzr[idx],
for (ax, dz, col, zlabel) in (
(ax1, dzr, real_color, "Real"),
(ax2, dzi, imag_color, "Imaginary"),
):

if ax is None:
continue

max_dz = np.max(dz)
min_dz = np.min(dz)

if isinstance(col, str) and col.startswith("#"):
col = mcolors.to_rgba_array(col)

dzn = dz < 0
if np.any(dzn):
fc = generate_facecolors(
xpos[dzn], ypos[dzn], zpos[dzn], dx[dzn], dy[dzn], dz[dzn], col
)
negative_bars = ax.bar3d(
xpos[dzn],
ypos[dzn],
zpos[dzn],
dx[dzn],
dy[dzn],
dz[dzn],
alpha=alpha,
zorder=zorder,
zorder=0.625,
)
b1.set_facecolors(fc1[6 * idx : 6 * idx + 6])

xlim, ylim = ax1.get_xlim(), ax1.get_ylim()
x = [xlim[0], xlim[1], xlim[1], xlim[0]]
y = [ylim[0], ylim[0], ylim[1], ylim[1]]
z = [0, 0, 0, 0]
verts = [list(zip(x, y, z))]

pc1 = Poly3DCollection(verts, alpha=0.15, facecolor="k", linewidths=1, zorder=1)

if min(dzr) < 0 < max(dzr):
ax1.add_collection3d(pc1)
ax1.set_xticks(np.arange(0.5, lx + 0.5, 1))
ax1.set_yticks(np.arange(0.5, ly + 0.5, 1))
if max_dzr != min_dzr:
ax1.axes.set_zlim3d(np.min(dzr), max(np.max(dzr) + 1e-9, max_dzi))
else:
if min_dzr == 0:
ax1.axes.set_zlim3d(np.min(dzr), max(np.max(dzr) + 1e-9, np.max(dzi)))
else:
ax1.axes.set_zlim3d(auto=True)
ax1.get_autoscalez_on()
ax1.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45, ha="right", va="top")
ax1.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5, ha="left", va="center")
ax1.set_zlabel("Re[$\\rho$]", fontsize=14)
for tick in ax1.zaxis.get_major_ticks():
tick.label1.set_fontsize(14)

if ax2 is not None:
fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
for idx, cur_zpos in enumerate(zpos):
if dzi[idx] > 0:
zorder = 2
else:
zorder = 0
b2 = ax2.bar3d(
xpos[idx],
ypos[idx],
cur_zpos,
dx[idx],
dy[idx],
dzi[idx],
negative_bars.set_facecolor(fc)

if min_dz < 0 < max_dz:
xlim, ylim = [0, lx], [0, ly]
verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
plane.set_zorder(0.75)
ax.add_collection3d(plane)

dzp = dz >= 0
if np.any(dzp):
fc = generate_facecolors(
xpos[dzp], ypos[dzp], zpos[dzp], dx[dzp], dy[dzp], dz[dzp], col
)
positive_bars = ax.bar3d(
xpos[dzp],
ypos[dzp],
zpos[dzp],
dx[dzp],
dy[dzp],
dz[dzp],
alpha=alpha,
zorder=zorder,
zorder=0.875,
)
b2.set_facecolors(fc2[6 * idx : 6 * idx + 6])

xlim, ylim = ax2.get_xlim(), ax2.get_ylim()
x = [xlim[0], xlim[1], xlim[1], xlim[0]]
y = [ylim[0], ylim[0], ylim[1], ylim[1]]
z = [0, 0, 0, 0]
verts = [list(zip(x, y, z))]

pc2 = Poly3DCollection(verts, alpha=0.2, facecolor="k", linewidths=1, zorder=1)

if min(dzi) < 0 < max(dzi):
ax2.add_collection3d(pc2)
ax2.set_xticks(np.arange(0.5, lx + 0.5, 1))
ax2.set_yticks(np.arange(0.5, ly + 0.5, 1))
if min_dzi != max_dzi:
eps = 0
ax2.axes.set_zlim3d(np.min(dzi), max(np.max(dzr) + 1e-9, np.max(dzi) + eps))
positive_bars.set_facecolor(fc)

ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)

ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
if max_dz != min_dz:
ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
else:
if min_dzi == 0:
ax2.set_zticks([0])
eps = 1e-9
ax2.axes.set_zlim3d(np.min(dzi), max(np.max(dzr) + 1e-9, np.max(dzi) + eps))
if min_dz == 0:
ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
else:
ax2.axes.set_zlim3d(auto=True)
ax.axes.set_zlim3d(auto=True)
ax.get_autoscalez_on()

ax.xaxis.set_ticklabels(
row_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
)
ax.yaxis.set_ticklabels(
column_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
)

for tick in ax.zaxis.get_major_ticks():
tick.label1.set_fontsize(max_font_size)
tick.label1.set_horizontalalignment("left")
tick.label1.set_verticalalignment("bottom")

ax2.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45, ha="right", va="top")
ax2.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5, ha="left", va="center")
ax2.set_zlabel("Im[$\\rho$]", fontsize=14)
for tick in ax2.zaxis.get_major_ticks():
tick.label1.set_fontsize(14)
ax2.get_autoscalez_on()
ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
ax.set_xmargin(0)
ax.set_ymargin(0)

fig.suptitle(title, fontsize=16)
fig.suptitle(title, fontsize=max_font_size * 1.25)
fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
if ax_real is None and ax_imag is None:
matplotlib_close_if_inline(fig)
if filename is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
fixes:
- |
Adjusted zoom, fontsize, and margins to fit the plot better for more figure
sizes.
- |
Corrected the Z-ordering behavior of bars and the zero-amplitude plane.
- |
Corrected display of negative real value bars
Binary file modified test/visual/mpl/graph/references/state_city.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 2e861ba

Please sign in to comment.