From 41776233c344621866eb147ab45bb2410243dfae Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Wed, 16 Oct 2024 15:06:25 +0200 Subject: [PATCH 1/6] Fix vectorisation issue in TDC --- .../transverse_deflecting_cavity.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index f15f49c1..2e715d2b 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -151,38 +151,47 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x_offset, y_offset, self.tilt, x, px, y, py ) + # Broadcast the TDC parameters together + broadcast_shape = torch.broadcast_shapes( + self.length.shape, + self.voltage.shape, + self.phase.shape, + self.frequency.shape, + self.tilt.shape, + ) + + length = self.length.expand(broadcast_shape) + voltage = self.voltage.expand(broadcast_shape) + # Add dimension for macro-particles + frequency = self.frequency.expand(broadcast_shape).unsqueeze(-1) + tdc_phase = self.phase.expand(broadcast_shape).unsqueeze(-1) + x, y, z = bmadx.track_a_drift( - self.length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV + length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV ) - voltage = self.voltage / p0c - k_rf = 2 * torch.pi * self.frequency / speed_of_light + voltage = (voltage / p0c).unsqueeze(-1) + k_rf = 2 * torch.pi * frequency / speed_of_light phase = ( 2 * torch.pi * ( - self.phase - - ( - bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV) - * self.frequency - ) + tdc_phase + - (bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV) * frequency) ) - ) + ) # Phase that the particle sees # TODO: Assigning px to px is really bad practice and should be separated into # two separate variables - px = px + voltage.unsqueeze(-1) * torch.sin(phase) + px = px + voltage * torch.sin(phase) - beta = ( + beta_old = ( (1 + pz) * p0c.unsqueeze(-1) / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2) ) - beta_old = beta E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old - E_new = E_old + voltage.unsqueeze(-1) * torch.cos( - phase - ) * k_rf * x * p0c.unsqueeze(-1) + E_new = E_old + voltage * torch.cos(phase) * k_rf * x * p0c.unsqueeze(-1) pc = torch.sqrt(E_new**2 - electron_mass_eV**2) beta = pc / E_new @@ -190,7 +199,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z = z * beta / beta_old x, y, z = bmadx.track_a_drift( - self.length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV + length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV ) x, px, y, py = bmadx.offset_particle_unset( From 9bc164043f6262678b8d250c3eb2c0dd67f7fee6 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Wed, 16 Oct 2024 15:12:44 +0200 Subject: [PATCH 2/6] Add vectorised TDC test --- tests/test_transverse_deflecting_cavity.py | 55 ++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index eafa8f69..6da37d09 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -37,3 +37,58 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): atol=1e-14 if dtype == torch.float64 else 0.00001, rtol=1e-14 if dtype == torch.float64 else 1e-6, ) + + +def test_transverse_deflecting_cavity_vectorisation(): + """ + Test that the TDC supports vectroized tracking + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + # Vectorise voltage + tdc = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor([[1e7], [2e7], [3e7]]), + phase=torch.tensor(0.4), + frequency=torch.tensor(1e9), + tracking_method="bmadx", + ) + + # Run tracking + _ = tdc.track(incoming_beam) + + # Vectorise phase + tdc2 = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor([[0.6], [0.5], [0.4]]), + frequency=torch.tensor(1e9), + tracking_method="bmadx", + ) + _ = tdc2.track(incoming_beam) + + # Vectorise frequency + tdc3 = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.tensor(1e7), + phase=torch.tensor(0.4), + frequency=torch.tensor([[1e9], [2e9], [3e9]]), + tracking_method="bmadx", + ) + _ = tdc3.track(incoming_beam) + + # Try vectorising all parameters + tdc4 = cheetah.TransverseDeflectingCavity( + length=torch.tensor(1.0), + voltage=torch.ones([4, 1, 1, 1]) * 1e7, + phase=torch.ones([1, 3, 1, 1]) * 0.4, + frequency=torch.ones([1, 1, 2, 1]) * 1e9, + tracking_method="bmadx", + ) + outgoing_beam = tdc4.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([4, 3, 2, 2]) From 8b2545b412fab3041d992a93099663696e2eca76 Mon Sep 17 00:00:00 2001 From: Chenran Xu Date: Wed, 16 Oct 2024 15:15:46 +0200 Subject: [PATCH 3/6] Update `CHANGELOG.md` --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa71be68..b5e383e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324) - `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324) - Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324) -- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240) (@jp-ga) +- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu) - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) - Implement a converter for lattice files imported from Elegant (see #222, #251, #273) (@hespe) From 481e2497bb643aed47902312dbcd4e13157042ff Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 16 Oct 2024 21:16:15 +0200 Subject: [PATCH 4/6] Separate tests --- tests/test_transverse_deflecting_cavity.py | 65 ++++++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/tests/test_transverse_deflecting_cavity.py b/tests/test_transverse_deflecting_cavity.py index 6da37d09..231c883d 100644 --- a/tests/test_transverse_deflecting_cavity.py +++ b/tests/test_transverse_deflecting_cavity.py @@ -39,9 +39,10 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype): ) -def test_transverse_deflecting_cavity_vectorisation(): +def test_transverse_deflecting_cavity_energy_length_vectorization(): """ - Test that the TDC supports vectroized tracking + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's length are vectorised. """ incoming_beam = cheetah.ParticleBeam.from_parameters( num_particles=torch.tensor(10_000), @@ -49,7 +50,6 @@ def test_transverse_deflecting_cavity_vectorisation(): sigma_py=torch.tensor(2e-7), energy=torch.tensor([50e6, 60e6]), ) - # Vectorise voltage tdc = cheetah.TransverseDeflectingCavity( length=torch.tensor(1.0), voltage=torch.tensor([[1e7], [2e7], [3e7]]), @@ -58,20 +58,46 @@ def test_transverse_deflecting_cavity_vectorisation(): tracking_method="bmadx", ) - # Run tracking - _ = tdc.track(incoming_beam) + outgoing_beam = tdc.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2]) - # Vectorise phase - tdc2 = cheetah.TransverseDeflectingCavity( + +def test_transverse_deflecting_cavity_energy_phase_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's phase are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc = cheetah.TransverseDeflectingCavity( length=torch.tensor(1.0), voltage=torch.tensor(1e7), phase=torch.tensor([[0.6], [0.5], [0.4]]), frequency=torch.tensor(1e9), tracking_method="bmadx", ) - _ = tdc2.track(incoming_beam) - # Vectorise frequency + outgoing_beam = tdc.track(incoming_beam) + + assert outgoing_beam.particles.shape[:-2] == torch.Size([3, 2]) + + +def test_transverse_deflecting_cavity_energy_frequency_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when the input beam's energy and the TDC's frequency are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) tdc3 = cheetah.TransverseDeflectingCavity( length=torch.tensor(1.0), voltage=torch.tensor(1e7), @@ -79,16 +105,31 @@ def test_transverse_deflecting_cavity_vectorisation(): frequency=torch.tensor([[1e9], [2e9], [3e9]]), tracking_method="bmadx", ) + _ = tdc3.track(incoming_beam) - # Try vectorising all parameters - tdc4 = cheetah.TransverseDeflectingCavity( + assert _.particles.shape[:-2] == torch.Size([3, 2]) + + +def test_transverse_deflecting_cavity_all_parameters_vectorization(): + """ + Test that vectorised tracking through a TDC throws now exception and outputs the + correct shape, when all parameters are vectorised. + """ + incoming_beam = cheetah.ParticleBeam.from_parameters( + num_particles=torch.tensor(10_000), + sigma_px=torch.tensor(2e-7), + sigma_py=torch.tensor(2e-7), + energy=torch.tensor([50e6, 60e6]), + ) + tdc = cheetah.TransverseDeflectingCavity( length=torch.tensor(1.0), voltage=torch.ones([4, 1, 1, 1]) * 1e7, phase=torch.ones([1, 3, 1, 1]) * 0.4, frequency=torch.ones([1, 1, 2, 1]) * 1e9, tracking_method="bmadx", ) - outgoing_beam = tdc4.track(incoming_beam) + + outgoing_beam = tdc.track(incoming_beam) assert outgoing_beam.particles.shape[:-2] == torch.Size([4, 3, 2, 2]) From 5b6ae5bd6bf9416dcff1f24988fac5fc6959176f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 16 Oct 2024 21:31:45 +0200 Subject: [PATCH 5/6] Switch to style where operations do broadcasting themselves and broadcasting happens only when needed with variables that need it --- .../transverse_deflecting_cavity.py | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index 2e715d2b..fd4ed2af 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -151,39 +151,28 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: x_offset, y_offset, self.tilt, x, px, y, py ) - # Broadcast the TDC parameters together - broadcast_shape = torch.broadcast_shapes( - self.length.shape, - self.voltage.shape, - self.phase.shape, - self.frequency.shape, - self.tilt.shape, - ) - - length = self.length.expand(broadcast_shape) - voltage = self.voltage.expand(broadcast_shape) - # Add dimension for macro-particles - frequency = self.frequency.expand(broadcast_shape).unsqueeze(-1) - tdc_phase = self.phase.expand(broadcast_shape).unsqueeze(-1) - x, y, z = bmadx.track_a_drift( - length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV + self.length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV ) - voltage = (voltage / p0c).unsqueeze(-1) - k_rf = 2 * torch.pi * frequency / speed_of_light + voltage = self.voltage / p0c + k_rf = 2 * torch.pi * self.frequency / speed_of_light + # Phase that the particle sees phase = ( 2 * torch.pi * ( - tdc_phase - - (bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV) * frequency) + self.phase.unsqueeze(-1) + - ( + bmadx.particle_rf_time(z, pz, p0c, electron_mass_eV) + * self.frequency.unsqueeze(-1) + ) ) - ) # Phase that the particle sees + ) # TODO: Assigning px to px is really bad practice and should be separated into # two separate variables - px = px + voltage * torch.sin(phase) + px = px + voltage.unsqueeze(-1) * torch.sin(phase) beta_old = ( (1 + pz) @@ -191,7 +180,9 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: / torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + electron_mass_eV**2) ) E_old = (1 + pz) * p0c.unsqueeze(-1) / beta_old - E_new = E_old + voltage * torch.cos(phase) * k_rf * x * p0c.unsqueeze(-1) + E_new = E_old + voltage.unsqueeze(-1) * torch.cos(phase) * k_rf.unsqueeze( + -1 + ) * x * p0c.unsqueeze(-1) pc = torch.sqrt(E_new**2 - electron_mass_eV**2) beta = pc / E_new @@ -199,7 +190,7 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z = z * beta / beta_old x, y, z = bmadx.track_a_drift( - length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV + self.length / 2, x, px, y, py, z, pz, p0c, electron_mass_eV ) x, px, y, py = bmadx.offset_particle_unset( From 2f2e46fe97ee053df9c58be7c0754b87ddfdf700 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 16 Oct 2024 21:32:17 +0200 Subject: [PATCH 6/6] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b5e383e8..ae4078ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324) - `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324) - Port Bmad-X tracking methods to Cheetah for `Quadrupole`, `Drift`, and `Dipole` (see #153, #240) (@jp-ga, @jank324) -- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu) +- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278) (@jp-ga, @cr-xu, @jank324) - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) - Implement a converter for lattice files imported from Elegant (see #222, #251, #273) (@hespe)