Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix vectorized tracking for TransverseDeflectingCavity #278

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324)
- `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324)
- Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324)
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240) (@jp-ga)
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu, @jank324)
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222, #251, #273) (@hespe)

Expand Down
14 changes: 7 additions & 7 deletions cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

voltage = self.voltage / p0c
k_rf = 2 * torch.pi * self.frequency / speed_of_light
# Phase that the particle sees
phase = (
2
* torch.pi
* (
self.phase
self.phase.unsqueeze(-1)
- (
bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV)
* self.frequency
* self.frequency.unsqueeze(-1)
)
)
)
Expand All @@ -173,16 +174,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
# two separate variables
px = px + voltage.unsqueeze(-1) * torch.sin(phase)

beta = (
beta_old = (
(1 + pz)
* p0c.unsqueeze(-1)
/ torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2)
)
beta_old = beta
E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old
E_new = E_old + voltage.unsqueeze(-1) * torch.cos(
phase
) * k_rf * x * p0c.unsqueeze(-1)
E_new = E_old + voltage.unsqueeze(-1) * torch.cos(phase) * k_rf.unsqueeze(
-1
) * x * p0c.unsqueeze(-1)
pc = torch.sqrt(E_new**2 - electron_mass_eV**2)
beta = pc / E_new

Expand Down
96 changes: 96 additions & 0 deletions tests/test_transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,99 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype):
atol=1e-14 if dtype == torch.float64 else 0.00001,
rtol=1e-14 if dtype == torch.float64 else 1e-6,
)


def test_transverse_deflecting_cavity_energy_length_vectorization():
"""
Test that vectorised tracking through a TDC throws now exception and outputs the
correct shape, when the input beam's energy and the TDC's length are vectorised.
"""
incoming_beam = cheetah.ParticleBeam.from_parameters(
num_particles=torch.tensor(10_000),
sigma_px=torch.tensor(2e-7),
sigma_py=torch.tensor(2e-7),
energy=torch.tensor([50e6, 60e6]),
)
tdc = cheetah.TransverseDeflectingCavity(
length=torch.tensor(1.0),
voltage=torch.tensor([[1e7], [2e7], [3e7]]),
phase=torch.tensor(0.4),
frequency=torch.tensor(1e9),
tracking_method="bmadx",
)

outgoing_beam = tdc.track(incoming_beam)

assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2])


def test_transverse_deflecting_cavity_energy_phase_vectorization():
"""
Test that vectorised tracking through a TDC throws now exception and outputs the
correct shape, when the input beam's energy and the TDC's phase are vectorised.
"""
incoming_beam = cheetah.ParticleBeam.from_parameters(
num_particles=torch.tensor(10_000),
sigma_px=torch.tensor(2e-7),
sigma_py=torch.tensor(2e-7),
energy=torch.tensor([50e6, 60e6]),
)
tdc = cheetah.TransverseDeflectingCavity(
length=torch.tensor(1.0),
voltage=torch.tensor(1e7),
phase=torch.tensor([[0.6], [0.5], [0.4]]),
frequency=torch.tensor(1e9),
tracking_method="bmadx",
)

outgoing_beam = tdc.track(incoming_beam)

assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2])


def test_transverse_deflecting_cavity_energy_frequency_vectorization():
"""
Test that vectorised tracking through a TDC throws now exception and outputs the
correct shape, when the input beam's energy and the TDC's frequency are vectorised.
"""
incoming_beam = cheetah.ParticleBeam.from_parameters(
num_particles=torch.tensor(10_000),
sigma_px=torch.tensor(2e-7),
sigma_py=torch.tensor(2e-7),
energy=torch.tensor([50e6, 60e6]),
)
tdc3 = cheetah.TransverseDeflectingCavity(
length=torch.tensor(1.0),
voltage=torch.tensor(1e7),
phase=torch.tensor(0.4),
frequency=torch.tensor([[1e9], [2e9], [3e9]]),
tracking_method="bmadx",
)

_ = tdc3.track(incoming_beam)

assert _.particles.shape[:-2] == torch.Size([3, 2])


def test_transverse_deflecting_cavity_all_parameters_vectorization():
"""
Test that vectorised tracking through a TDC throws now exception and outputs the
correct shape, when all parameters are vectorised.
"""
incoming_beam = cheetah.ParticleBeam.from_parameters(
num_particles=torch.tensor(10_000),
sigma_px=torch.tensor(2e-7),
sigma_py=torch.tensor(2e-7),
energy=torch.tensor([50e6, 60e6]),
)
tdc = cheetah.TransverseDeflectingCavity(
length=torch.tensor(1.0),
voltage=torch.ones([4, 1, 1, 1]) * 1e7,
phase=torch.ones([1, 3, 1, 1]) * 0.4,
frequency=torch.ones([1, 1, 2, 1]) * 1e9,
tracking_method="bmadx",
)

outgoing_beam = tdc.track(incoming_beam)

assert outgoing_beam.particles.shape[:-2] == torch.Size([4, 3, 2, 2])