Skip to content

Commit 57d287e

Browse files
authored
Merge pull request #172 from desy-ml/fix-cavity-tracking-issue
Fix cavity tracking issue introduced in recent updates
2 parents b9f971d + 1083409 commit 57d287e

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### 🚨 Breaking Changes
66

7-
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170) (@jank324, @cr-xu)
7+
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172) (@jank324, @cr-xu)
88

99
### 🚀 Features
1010

cheetah/accelerator/cavity.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ def _track_beam(self, incoming: Beam) -> Beam:
155155
outgoing_energy.unsqueeze(-1) * beta1.unsqueeze(-1)
156156
) * (
157157
torch.cos(
158-
incoming.particles[..., 4]
158+
-1
159+
* incoming.particles[..., 4]
159160
* beta0.unsqueeze(-1)
160161
* k.unsqueeze(-1)
161162
+ phi.unsqueeze(-1)

tests/test_compare_ocelot.py

+54-4
Original file line numberDiff line numberDiff line change
@@ -698,13 +698,63 @@ def test_cavity():
698698

699699
# Compare
700700
assert np.isclose(outgoing_beam.beta_x.cpu().numpy(), derived_twiss.beta_x)
701-
assert np.isclose(
702-
outgoing_beam.alpha_x.cpu().numpy(), derived_twiss.alpha_x, rtol=2e-5
703-
)
701+
assert np.isclose(outgoing_beam.alpha_x.cpu().numpy(), derived_twiss.alpha_x)
704702
assert np.isclose(outgoing_beam.beta_y.cpu().numpy(), derived_twiss.beta_y)
703+
assert np.isclose(outgoing_beam.alpha_y.cpu().numpy(), derived_twiss.alpha_y)
705704
assert np.isclose(
706-
outgoing_beam.alpha_y.cpu().numpy(), derived_twiss.alpha_y, rtol=2e-5
705+
outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array)
706+
)
707+
assert np.allclose(
708+
outgoing_beam.particles[:, :, 5].cpu().numpy(),
709+
outgoing_parray.rparticles.transpose()[:, 5],
710+
)
711+
assert np.allclose(
712+
outgoing_beam.particles[:, :, 4].cpu().numpy(),
713+
outgoing_parray.rparticles.transpose()[:, 4],
714+
)
715+
716+
717+
def test_cavity_non_zero_phase():
718+
"""Compare tracking through a cavity with a phase offset."""
719+
# Ocelot
720+
tws = ocelot.Twiss()
721+
tws.beta_x = 5.91253677
722+
tws.alpha_x = 3.55631308
723+
tws.beta_y = 5.91253677
724+
tws.alpha_y = 3.55631308
725+
tws.emit_x = 3.494768647122823e-09
726+
tws.emit_y = 3.497810737006068e-09
727+
tws.gamma_x = (1 + tws.alpha_x**2) / tws.beta_x
728+
tws.gamma_y = (1 + tws.alpha_y**2) / tws.beta_y
729+
tws.E = 6e-3
730+
731+
p_array = ocelot.generate_parray(tws=tws, charge=5e-9)
732+
733+
cell = [ocelot.Cavity(l=1.0377, v=0.01815975, freq=1.3e9, phi=30.0)]
734+
lattice = ocelot.MagneticLattice(cell)
735+
navigator = ocelot.Navigator(lattice=lattice)
736+
737+
_, outgoing_parray = ocelot.track(lattice, deepcopy(p_array), navigator)
738+
derived_twiss = ocelot.cpbd.beam.get_envelope(outgoing_parray)
739+
740+
# Cheetah
741+
incoming_beam = cheetah.ParticleBeam.from_ocelot(
742+
parray=p_array, dtype=torch.float64
743+
)
744+
cheetah_cavity = cheetah.Cavity(
745+
length=torch.tensor([1.0377]),
746+
voltage=torch.tensor([0.01815975e9]),
747+
frequency=torch.tensor([1.3e9]),
748+
phase=torch.tensor([30.0]),
749+
dtype=torch.float64,
707750
)
751+
outgoing_beam = cheetah_cavity.track(incoming_beam)
752+
753+
# Compare
754+
assert np.isclose(outgoing_beam.beta_x.cpu().numpy(), derived_twiss.beta_x)
755+
assert np.isclose(outgoing_beam.alpha_x.cpu().numpy(), derived_twiss.alpha_x)
756+
assert np.isclose(outgoing_beam.beta_y.cpu().numpy(), derived_twiss.beta_y)
757+
assert np.isclose(outgoing_beam.alpha_y.cpu().numpy(), derived_twiss.alpha_y)
708758
assert np.isclose(
709759
outgoing_beam.total_charge.cpu().numpy(), np.sum(outgoing_parray.q_array)
710760
)

0 commit comments

Comments
 (0)