Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions apace/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, VPacker
from matplotlib.path import Path
from matplotlib.widgets import Slider
from matplotlib.transforms import blended_transform_factory

from .classes import (
Base,
Expand Down Expand Up @@ -67,22 +68,35 @@ def draw_elements(
location: str = "top",
):
"""Draw the elements of a lattice onto a matplotlib axes."""
transform = blended_transform_factory(ax.transData, ax.transAxes)
x_min, x_max = ax.get_xlim()
y_min, y_max = ax.get_ylim()
rect_height = 0.05 * (y_max - y_min)
y_span = y_max - y_min
rect_height = 0.05
if location == "top":
y0 = y_max = y_max + rect_height
y0 = 1
elif location == "center":
y0 = 0.5
else:
y0 = y_min - rect_height
y_min -= 3 * rect_height
plt.hlines(y0, x_min, x_max, color="black", linewidth=1)
ax.set_ylim(y_min, y_max)
y0 = 2 * rect_height
ax.hlines(
[2 * rect_height, 4 * rect_height],
xmin=x_min,
xmax=x_max,
color="black",
linewidth=1,
transform=transform,
)
yticks = ax.get_yticks()
ax.set_ylim(y_min + y_span * (1 - 1 / (1 - 4 * rect_height)), y_max)
ax.set_yticks(yticks)

sign = -1
start = end = 0
for element, group in groupby(lattice.sequence):
start = end
end += element.length * sum(1 for _ in group)
group_length = sum(1 for _ in group)
end += element.length * group_length
if end <= x_min:
continue
elif start >= x_max:
Expand All @@ -91,6 +105,8 @@ def draw_elements(
try:
color = ELEMENT_COLOR[type(element)]
except KeyError:
# don't switch sign after consecutive drifts
sign *= 1 - 2 * (group_length & 1)
continue

y0_local = y0
Expand All @@ -102,13 +118,14 @@ def draw_elements(
(max(start, x_min), y0_local - 0.5 * rect_height),
min(end, x_max) - max(start, x_min),
rect_height,
facecolor=color,
color=color,
clip_on=False,
zorder=10,
transform=transform,
)
)
if labels and type(element) in {Dipole, Quadrupole}:
sign = -sign

if labels:
ax.annotate(
element.name,
xy=(0.5 * (start + end), y0 + sign * rect_height),
Expand All @@ -117,6 +134,7 @@ def draw_elements(
va="bottom" if sign > 0 else "top",
annotation_clip=False,
zorder=11,
xycoords=transform,
)


Expand All @@ -138,7 +156,7 @@ def draw_sub_lattices(

if labels:
y_min, y_max = ax.get_ylim()
height = 0.08 * (y_max - y_min)
height = 0.2 * (y_max - y_min)
if location == "top":
y0 = y_max - height
else:
Expand Down