diff --git a/CHANGELOG.md b/CHANGELOG.md index fa71be68..ae4078ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324) - `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324) - Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324) -- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240) (@jp-ga) +- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu, @jank324) - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) - Implement a converter for lattice files imported from Elegant (see #222, #251, #273) (@hespe) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index f15f49c1..fd4ed2af 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -157,14 +157,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: voltage = self.voltage / p0c k_rf = 2 * torch.pi * self.frequency / speed_of_light + # Phase that the particle sees phase = ( 2 * torch.pi * ( - self.phase + self.phase.unsqueeze(-1) - ( bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV) - * self.frequency + * self.frequency.unsqueeze(-1) ) ) ) @@ -173,16 +174,15 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: # two separate variables px = px + voltage.unsqueeze(-1) * torch.sin(phase) - beta = ( + beta_old = ( (1 + pz) * p0c.unsqueeze(-1) / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2) ) - beta_old = beta E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old - E_new = E_old + voltage.unsqueeze(-1) * torch.cos( - phase - ) * k_rf * x * p0c.unsqueeze(-1) + E_new = E_old + voltage.unsqueeze(-1) * torch.cos(phase) * k_rf.unsqueeze( + -1 + ) * x * p0c.unsqueeze(-1) pc = torch.sqrt(E_new**2 - electron_mass_eV**2) beta = pc / E_new diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index eafa8f69..231c883d 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -37,3 +37,99 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): atol=1e-14 if dtype == torch.float64 else 0.00001, rtol=1e-14 if dtype == torch.float64 else 1e-6, ) + + +def test_transverse_deflecting_cavity_energy_length_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's length are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor([[1e7], [2e7], [3e7]]), + phase=torch.tensor(0.4), + frequency=torch.tensor(1e9), + tracking_method="bmadx", + ) + + outgoing_beam = tdc.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2]) + + +def test_transverse_deflecting_cavity_energy_phase_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's phase are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor([[0.6], [0.5], [0.4]]), + frequency=torch.tensor(1e9), + tracking_method="bmadx", + ) + + outgoing_beam = tdc.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2]) + + +def test_transverse_deflecting_cavity_energy_frequency_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's frequency are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc3 = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor(0.4), + frequency=torch.tensor([[1e9], [2e9], [3e9]]), + tracking_method="bmadx", + ) + + _ = tdc3.track(incoming_beam) + + assert _.particles.shape[:-2] == torch.Size([3, 2]) + + +def test_transverse_deflecting_cavity_all_parameters_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when all parameters are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.ones([4, 1, 1, 1]) * 1e7, + phase=torch.ones([1, 3, 1, 1]) * 0.4, + frequency=torch.ones([1, 1, 2, 1]) * 1e9, + tracking_method="bmadx", + ) + + outgoing_beam = tdc.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([4, 3, 2, 2])