Skip to content

Commit c96f3c4

Browse files
authored
Fix twiss plot (#213)
* Add a test that finds Twiss plotting error * Fix error caused by Twiss plot not being updated since vectorisation * Add Twiss plot fix to changelog * Implement vectorised linspace utility function * Fix `make_linspaced` beam creation * Fix plotting in 0-dimensional case * Add reference particle plotting test case * Fix 0-dimensional case for Twiss plot * Add tests for plotting of vectorised models * Fix 2D reference particle plot but bread 0D one * Fix 0D particle trace plot * Fix 2D Twiss plot test * Remove debugging code * Another fix of debugging leftovers * Slightly more consistent parameter naming
1 parent dcb6318 commit c96f3c4

26 files changed

+476
-235
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
66

77
### 🚨 Breaking Changes
88

9-
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe, @roussel-ryan)
9+
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233, #258) (@jank324, @cr-xu, @hespe, @roussel-ryan)
1010
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
1111
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
1212

cheetah/accelerator/aperture.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,16 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
113113
# TODO: Implement splitting for aperture properly, for now just return self
114114
return [self]
115115

116-
def plot(self, ax: plt.Axes, s: float) -> None:
116+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
117+
plot_s = s[vector_idx] if s.dim() > 0 else s
118+
117119
alpha = 1 if self.is_active else 0.2
118120
height = 0.4
119121

120122
dummy_length = 0.0
121123

122124
patch = Rectangle(
123-
(s, 0), dummy_length, height, color="tab:pink", alpha=alpha, zorder=2
125+
(plot_s, 0), dummy_length, height, color="tab:pink", alpha=alpha, zorder=2
124126
)
125127
ax.add_patch(patch)
126128

cheetah/accelerator/bpm.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ def track(self, incoming: Beam) -> Beam:
5151
def split(self, resolution: torch.Tensor) -> list[Element]:
5252
return [self]
5353

54-
def plot(self, ax: plt.Axes, s: float) -> None:
54+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
55+
plot_s = s[vector_idx] if s.dim() > 0 else s
56+
5557
alpha = 1 if self.is_active else 0.2
5658
patch = Rectangle(
57-
(s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
59+
(plot_s, -0.3), 0, 0.3 * 2, color="darkkhaki", alpha=alpha, zorder=2
5860
)
5961
ax.add_patch(patch)
6062

cheetah/accelerator/cavity.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,15 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
341341
# element itself
342342
return [self]
343343

344-
def plot(self, ax: plt.Axes, s: float) -> None:
344+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
345+
plot_s = s[vector_idx] if s.dim() > 0 else s
346+
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
347+
345348
alpha = 1 if self.is_active else 0.2
346349
height = 0.4
347350

348351
patch = Rectangle(
349-
(s, 0), self.length[0], height, color="gold", alpha=alpha, zorder=2
352+
(plot_s, 0), plot_length, height, color="gold", alpha=alpha, zorder=2
350353
)
351354
ax.add_patch(patch)
352355

cheetah/accelerator/custom_transfer_map.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,11 @@ def defining_features(self) -> list[str]:
102102
def split(self, resolution: torch.Tensor) -> list[Element]:
103103
return [self]
104104

105-
def plot(self, ax: plt.Axes, s: float) -> None:
105+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
106+
plot_s = s[vector_idx] if s.dim() > 0 else s
107+
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
108+
106109
height = 0.4
107110

108-
patch = Rectangle((s, 0), self.length[0], height, color="tab:olive", zorder=2)
111+
patch = Rectangle((plot_s, 0), plot_length, height, color="tab:olive", zorder=2)
109112
ax.add_patch(patch)

cheetah/accelerator/dipole.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -480,11 +480,15 @@ def defining_features(self) -> list[str]:
480480
"tracking_method",
481481
]
482482

483-
def plot(self, ax: plt.Axes, s: float) -> None:
483+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
484+
plot_s = s[vector_idx] if s.dim() > 0 else s
485+
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
486+
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle
487+
484488
alpha = 1 if self.is_active else 0.2
485-
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
489+
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
486490

487491
patch = Rectangle(
488-
(s, 0), self.length[0], height, color="tab:green", alpha=alpha, zorder=2
492+
(plot_s, 0), plot_length, height, color="tab:green", alpha=alpha, zorder=2
489493
)
490494
ax.add_patch(patch)

cheetah/accelerator/drift.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
136136
for i in range(num_splits)
137137
]
138138

139-
def plot(self, ax: plt.Axes, s: float) -> None:
139+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
140140
pass
141141

142142
@property

cheetah/accelerator/element.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,16 @@ def split(self, resolution: torch.Tensor) -> list["Element"]:
126126
raise NotImplementedError
127127

128128
@abstractmethod
129-
def plot(self, ax: plt.Axes, s: float) -> None:
129+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
130130
"""
131131
Plot a representation of this element into a `matplotlib` Axes at position `s`.
132132
133133
:param ax: Axes to plot the representation into.
134134
:param s: Position of the object along s in meters.
135+
:param vector_idx: Index of the vector dimension to plot. If the model has more
136+
than one vector dimension, this can be used to select a specific one. In the
137+
case of present vector dimension but no index provided, the first one is
138+
used by default.
135139
"""
136140
raise NotImplementedError
137141

cheetah/accelerator/horizontal_corrector.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,16 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
8080
for i in range(num_splits)
8181
]
8282

83-
def plot(self, ax: plt.Axes, s: float) -> None:
83+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
84+
plot_s = s[vector_idx] if s.dim() > 0 else s
85+
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
86+
plot_angle = self.angle[vector_idx] if self.angle.dim() > 0 else self.angle
87+
8488
alpha = 1 if self.is_active else 0.2
85-
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
89+
height = 0.8 * (np.sign(plot_angle) if self.is_active else 1)
8690

8791
patch = Rectangle(
88-
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
92+
(plot_s, 0), plot_length, height, color="tab:blue", alpha=alpha, zorder=2
8993
)
9094
ax.add_patch(patch)
9195

cheetah/accelerator/marker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def is_skippable(self) -> bool:
3737
def split(self, resolution: torch.Tensor) -> list[Element]:
3838
return [self]
3939

40-
def plot(self, ax: plt.Axes, s: float) -> None:
40+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
4141
# Do nothing on purpose. Maybe later we decide markers should be shown, but for
4242
# now they are invisible.
4343
pass

cheetah/accelerator/quadrupole.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,15 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
217217
for i in range(num_splits)
218218
]
219219

220-
def plot(self, ax: plt.Axes, s: float) -> None:
220+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
221+
plot_k1 = self.k1[vector_idx] if self.k1.dim() > 0 else self.k1
222+
plot_s = s[vector_idx] if s.dim() > 0 else s
223+
plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length
224+
221225
alpha = 1 if self.is_active else 0.2
222-
height = 0.8 * (np.sign(self.k1[0]) if self.is_active else 1)
226+
height = 0.8 * (np.sign(plot_k1) if self.is_active else 1)
223227
patch = Rectangle(
224-
(s, 0), self.length[0], height, color="tab:red", alpha=alpha, zorder=2
228+
(plot_s, 0), plot_length, height, color="tab:red", alpha=alpha, zorder=2
225229
)
226230
ax.add_patch(patch)
227231

cheetah/accelerator/screen.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,13 @@ def set_read_beam(self, value: Beam) -> None:
297297
def split(self, resolution: torch.Tensor) -> list[Element]:
298298
return [self]
299299

300-
def plot(self, ax: plt.Axes, s: float) -> None:
300+
def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None:
301+
plot_s = s[vector_idx] if s.dim() > 0 else s
302+
301303
alpha = 1 if self.is_active else 0.2
304+
302305
patch = Rectangle(
303-
(s, -0.6), 0, 0.6 * 2, color="tab:green", alpha=alpha, zorder=2
306+
(plot_s, -0.6), 0, 0.6 * 2, color="tab:green", alpha=alpha, zorder=2
304307
)
305308
ax.add_patch(patch)
306309

0 commit comments

Comments
 (0)