Skip to content

Commit 09cd0d3

Browse files
committed
Fix broken references to mass_eV
1 parent d74fc78 commit 09cd0d3

10 files changed

+26
-17
lines changed

cheetah/accelerator/cavity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
105105

106106
phi = torch.deg2rad(self.phase)
107107

108-
tm = self.transfer_map(incoming.energy, incoming.mass_eV)
108+
tm = self.transfer_map(incoming.energy, incoming.species.mass_eV)
109109
if isinstance(incoming, ParameterBeam):
110110
outgoing_mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
111111
outgoing_cov = torch.matmul(
@@ -153,7 +153,7 @@ def _track_beam(self, incoming: Beam) -> Beam:
153153
- torch.cos(phi).unsqueeze(-1)
154154
)
155155

156-
dgamma = self.voltage / incoming.mass_eV
156+
dgamma = self.voltage / incoming.species.mass_eV
157157
if torch.any(delta_energy > 0):
158158
T566 = (
159159
self.length

cheetah/accelerator/custom_transfer_map.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,25 @@ def from_merging_elements(
6363
)
6464

6565
device = (
66-
elements[0].transfer_map(incoming_beam.energy, incoming_beam.mass_eV).device
66+
elements[0]
67+
.transfer_map(incoming_beam.energy, incoming_beam.species.mass_eV)
68+
.device
6769
)
6870
dtype = (
69-
elements[0].transfer_map(incoming_beam.energy, incoming_beam.mass_eV).dtype
71+
elements[0]
72+
.transfer_map(incoming_beam.energy, incoming_beam.species.mass_eV)
73+
.dtype
7074
)
7175

7276
tm = torch.eye(7, device=device, dtype=dtype).repeat(
7377
(*incoming_beam.energy.shape, 1, 1)
7478
)
7579
for element in elements:
7680
tm = torch.matmul(
77-
element.transfer_map(incoming_beam.energy, incoming_beam.mass_eV), tm
81+
element.transfer_map(
82+
incoming_beam.energy, incoming_beam.species.mass_eV
83+
),
84+
tm,
7885
)
7986
incoming_beam = element.track(incoming_beam)
8087

cheetah/accelerator/dipole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
186186
py = incoming.py
187187
tau = incoming.tau
188188
delta = incoming.p
189-
mc2 = incoming.mass_eV
189+
mc2 = incoming.species.mass_eV
190190

191191
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(tau, delta, incoming.energy, mc2)
192192

cheetah/accelerator/drift.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,18 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
9090
delta = incoming.p
9191

9292
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(
93-
tau, delta, incoming.energy, incoming.mass_eV
93+
tau, delta, incoming.energy, incoming.species.mass_eV
9494
)
9595

9696
# Begin Bmad-X tracking
9797
x, y, z = bmadx.track_a_drift(
98-
self.length, x, px, y, py, z, pz, p0c, incoming.mass_eV
98+
self.length, x, px, y, py, z, pz, p0c, incoming.species.mass_eV
9999
)
100100
# End of Bmad-X tracking
101101

102102
# Convert back to Cheetah coordinates
103103
tau, delta, ref_energy = bmadx.bmad_to_cheetah_z_pz(
104-
z, pz, p0c, incoming.mass_eV
104+
z, pz, p0c, incoming.species.mass_eV
105105
)
106106

107107
# Broadcast to align their shapes so that they can be stacked

cheetah/accelerator/element.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def track(self, incoming: Beam) -> Beam:
6464
:return: Beam of particles exiting the element.
6565
"""
6666
if isinstance(incoming, ParameterBeam):
67-
tm = self.transfer_map(incoming.energy, incoming.mass_eV)
67+
tm = self.transfer_map(incoming.energy, incoming.species.mass_eV)
6868
mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
6969
cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1)))
7070
return ParameterBeam(
@@ -76,7 +76,7 @@ def track(self, incoming: Beam) -> Beam:
7676
dtype=mu.dtype,
7777
)
7878
elif isinstance(incoming, ParticleBeam):
79-
tm = self.transfer_map(incoming.energy, incoming.mass_eV)
79+
tm = self.transfer_map(incoming.energy, incoming.species.mass_eV)
8080
new_particles = torch.matmul(incoming.particles, tm.transpose(-2, -1))
8181
return ParticleBeam(
8282
new_particles,

cheetah/accelerator/quadrupole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
114114
py = incoming.py
115115
tau = incoming.tau
116116
delta = incoming.p
117-
mc2 = incoming.mass_eV
117+
mc2 = incoming.species.mass_eV
118118

119119
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(tau, delta, incoming.energy, mc2)
120120

cheetah/accelerator/transverse_deflecting_cavity.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
118118
py = incoming.py
119119
tau = incoming.tau
120120
delta = incoming.p
121-
mc2 = incoming.mass_eV
121+
mc2 = incoming.species.mass_eV
122122

123123
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(tau, delta, incoming.energy, mc2)
124124

cheetah/particles/particle_beam.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,7 @@ def energies(self) -> torch.Tensor:
15471547
@property
15481548
def momenta(self) -> torch.Tensor:
15491549
"""Momenta of the individual particles."""
1550-
return torch.sqrt(self.energies**2 - self.mass_eV**2)
1550+
return torch.sqrt(self.energies**2 - self.species.mass_eV**2)
15511551

15521552
def clone(self) -> "ParticleBeam":
15531553
return ParticleBeam(

tests/test_cavity.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def test_vectorized_cavity_zero_voltage(voltage):
7272

7373
outgoing = cavity.track(incoming)
7474

75-
assert not torch.isnan(cavity.transfer_map(incoming.energy, incoming.mass_eV)).any()
75+
assert not torch.isnan(
76+
cavity.transfer_map(incoming.energy, incoming.species.mass_eV)
77+
).any()
7678

7779
assert not torch.isnan(outgoing.sigma_x).any()
7880
assert not torch.isnan(outgoing.sigma_y).any()

tests/test_speed_optimizations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,10 @@ def test_skippable_elements_reset():
232232
)
233233

234234
original_tm = original_segment.elements[2].transfer_map(
235-
energy=incoming_beam.energy, particle_mass_eV=incoming_beam.mass_eV
235+
energy=incoming_beam.energy, particle_mass_eV=incoming_beam.species.mass_eV
236236
)
237237
merged_tm = merged_segment.elements[2].transfer_map(
238-
energy=incoming_beam.energy, particle_mass_eV=incoming_beam.mass_eV
238+
energy=incoming_beam.energy, particle_mass_eV=incoming_beam.species.mass_eV
239239
)
240240

241241
assert torch.allclose(original_tm, merged_tm)

0 commit comments

Comments
 (0)