From 0d2604b1ad25463e82762b0520cd94afff7d652f Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 09:36:53 +0100 Subject: [PATCH 01/31] Add cloning method for beams --- cheetah/accelerator/bpm.py | 3 +-- cheetah/particles/beam.py | 4 ++++ cheetah/particles/parameter_beam.py | 9 +++++++++ cheetah/particles/particle_beam.py | 8 ++++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index d0e636bd..14ec9130 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Optional import matplotlib.pyplot as plt @@ -46,7 +45,7 @@ def track(self, incoming: Beam) -> Beam: else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") - return deepcopy(incoming) + return incoming.clone() def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 427e0329..1f6fd7b5 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -357,6 +357,10 @@ def alpha_y(self) -> torch.Tensor: """Alpha function in y direction, dimensionless.""" return -self.sigma_ypy / self.emittance_y + def clone(self) -> "Beam": + """Return a cloned beam that does not share the underlying storage.""" + raise NotImplementedError + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={self.mu_x}, mu_px={self.mu_px}," diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index dcdbb8c0..df311ed7 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -423,6 +423,15 @@ def sigma_xpx(self) -> torch.Tensor: def sigma_ypy(self) -> torch.Tensor: return self._cov[..., 2, 3] + def clone(self) -> "ParameterBeam": + """Return a cloned beam that does not share the underlying storage.""" + return ParameterBeam( + mu=self._mu.clone(), + cov=self._cov.clone(), + total_charge=self.total_charge.clone(), + energy=self.energy.clone(), + ) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(mu_x={repr(self.mu_x)}," diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 5b83bab1..d6bb4a46 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -902,6 +902,14 @@ def momenta(self) -> torch.Tensor: """Momenta of the individual particles.""" return torch.sqrt(self.energies**2 - electron_mass_eV**2) + def clone(self) -> "ParticleBeam": + """Return a cloned beam that does not share the underlying storage.""" + return ParticleBeam( + particles=self.particles.clone(), + particle_charges=self.particle_charges.clone(), + energy=self.energy.clone(), + ) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(n={repr(self.num_particles)}," From 8a883d8d4f490b9e2a8a8daf891385623cb3ef4b Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 11:59:47 +0100 Subject: [PATCH 02/31] Remove unnecessary comments --- cheetah/particles/parameter_beam.py | 1 - cheetah/particles/particle_beam.py | 1 - 2 files changed, 2 deletions(-) diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index df311ed7..fb61f70b 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -424,7 +424,6 @@ def sigma_ypy(self) -> torch.Tensor: return self._cov[..., 2, 3] def clone(self) -> "ParameterBeam": - """Return a cloned beam that does not share the underlying storage.""" return ParameterBeam( mu=self._mu.clone(), cov=self._cov.clone(), diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index d6bb4a46..bf46e2fb 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -903,7 +903,6 @@ def momenta(self) -> torch.Tensor: return torch.sqrt(self.energies**2 - electron_mass_eV**2) def clone(self) -> "ParticleBeam": - """Return a cloned beam that does not share the underlying storage.""" return ParticleBeam( particles=self.particles.clone(), particle_charges=self.particle_charges.clone(), From 3e99a57adbfc45851dd735a4bc99c8cbdb02b5d1 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 12:10:15 +0100 Subject: [PATCH 03/31] Add clone method for elements --- cheetah/accelerator/aperture.py | 9 +++++++++ cheetah/accelerator/bpm.py | 3 +++ cheetah/accelerator/cavity.py | 9 +++++++++ cheetah/accelerator/custom_transfer_map.py | 7 +++++++ cheetah/accelerator/dipole.py | 18 ++++++++++++++++++ cheetah/accelerator/drift.py | 7 +++++++ cheetah/accelerator/element.py | 5 +++++ cheetah/accelerator/horizontal_corrector.py | 5 +++++ cheetah/accelerator/marker.py | 3 +++ cheetah/accelerator/quadrupole.py | 11 +++++++++++ cheetah/accelerator/screen.py | 13 +++++++++++++ cheetah/accelerator/segment.py | 5 +++++ cheetah/accelerator/solenoid.py | 8 ++++++++ cheetah/accelerator/space_charge_kick.py | 12 ++++++++++++ .../transverse_deflecting_cavity.py | 13 +++++++++++++ cheetah/accelerator/undulator.py | 7 +++++++ cheetah/accelerator/vertical_corrector.py | 5 +++++ 17 files changed, 140 insertions(+) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index cdf5d256..af6c847e 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -109,6 +109,15 @@ def track(self, incoming: Beam) -> Beam: else ParticleBeam.empty ) + def clone(self) -> "Aperture": + return Aperture( + x_max=self.x_max.clone(), + y_max=self.y_max.clone(), + shape=self.shape, + is_active=self.is_active, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for aperture properly, for now just return self return [self] diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index 14ec9130..73efdba8 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -47,6 +47,9 @@ def track(self, incoming: Beam) -> Beam: return incoming.clone() + def clone(self) -> "BPM": + return BPM(is_active=self.is_active, name=self.name) + def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index c7a89e05..8e4ec1b3 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -336,6 +336,15 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: return R + def clone(self) -> "Cavity": + return Cavity( + length=self.length.clone(), + voltage=self.voltage.clone(), + phase=self.phase.clone(), + frequency=self.frequency.clone(), + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 5baf87d2..32ba51c4 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -99,6 +99,13 @@ def __repr__(self): def defining_features(self) -> list[str]: return super().defining_features + ["transfer_map"] + def clone(self) -> "CustomTransferMap": + return CustomTransferMap( + transfer_map=self._transfer_map.clone(), + length=self.length.clone(), + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 5e919fbf..28cfef29 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -439,6 +439,24 @@ def _transfer_map_exit(self) -> torch.Tensor: return tm + def clone(self) -> "Dipole": + return Dipole( + length=self.length.clone(), + angle=self.angle.clone(), + k1=self.k1.clone(), + e1=self.e1.clone(), + e2=self.e2.clone(), + tilt=self.tilt.clone(), + gap=self.gap.clone(), + gap_exit=self.gap_exit.clone(), + fringe_integral=self.fringe_integral.clone(), + fringe_integral_exit=self.fringe_integral_exit.clone(), + fringe_at=self.fringe_at, + fringe_type=self.fringe_type, + tracking_method=self.tracking_method, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for dipole properly, for now just returns the # element itself diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index eb5fb187..627da035 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -124,6 +124,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: def is_skippable(self) -> bool: return self.tracking_method == "cheetah" + def clone(self) -> "Drift": + return Drift( + length=self.length.clone(), + tracking_method=self.tracking_method, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index bfe1df7c..1baa328f 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -113,6 +113,11 @@ def defining_features(self) -> list[str]: """ return [] + @abstractmethod + def clone(self) -> "Element": + """Create a copy of the element which does not share the underlying memory.""" + raise NotImplementedError + @abstractmethod def split(self, resolution: torch.Tensor) -> list["Element"]: """ diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index a00c2cbe..dd423a18 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -70,6 +70,11 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.angle != 0) + def clone(self) -> "HorizontalCorrector": + return HorizontalCorrector( + length=self.length.clone(), angle=self.angle.clone(), name=self.name + ) + def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index 37daa593..b0829a45 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -34,6 +34,9 @@ def track(self, incoming: Beam) -> Beam: def is_skippable(self) -> bool: return True + def clone(self) -> "Marker": + return Marker(name=self.name) + def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index db6a559d..f5f9c605 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -201,6 +201,17 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.k1 != 0) + def clone(self) -> "Quadrupole": + return Quadrupole( + length=self.length.clone(), + k1=self.k1.clone(), + misalignment=self.misalignment.clone(), + tilt=self.tilt.clone(), + num_steps=self.num_steps, + tracking_method=self.tracking_method, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index d7b153d6..8a8836e9 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -294,6 +294,19 @@ def set_read_beam(self, value: Beam) -> None: self._read_beam = value self.cached_reading = None + def clone(self) -> "Screen": + return Screen( + resolution=self.resolution.clone(), + pixel_size=self.pixel_size.clone(), + binning=self.binning.clone(), + misalignment=self.misalignment.clone(), + method=self.method, + kde_bandwidth=self.kde_bandwidth.clone(), + is_blocking=self.is_blocking, + is_active=self.is_active, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index 290195e8..b1a20b5d 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -379,6 +379,11 @@ def track(self, incoming: Beam) -> Beam: return incoming + def clone(self) -> "Segment": + return Segment( + elements=[element.clone() for element in self.elements], name=self.name + ) + def split(self, resolution: torch.Tensor) -> list[Element]: return [ split_element diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index f8faf24c..5d4bc0fb 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -112,6 +112,14 @@ def is_active(self) -> bool: def is_skippable(self) -> bool: return True + def clone(self) -> "Solenoid": + return Solenoid( + length=self.length.clone(), + k=self.k.clone(), + misalignment=self.misalignment.clone(), + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for solenoid properly, for now just return self return [self] diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 7bb7d7f3..80825abc 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -633,6 +633,18 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") + def clone(self) -> "SpaceChargeKick": + return SpaceChargeKick( + effect_length=self.effect_length.clone(), + num_grid_points_x=self.grid_shape[0], + num_grid_points_y=self.grid_shape[1], + num_grid_points_tau=self.grid_shape[2], + grid_extend_x=self.grid_extend_x.clone(), + grid_extend_y=self.grid_extend_y.clone(), + grid_extend_tau=self.grid_extend_tau.clone(), + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for SpaceCharge properly, for now just returns the # element itself diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index fd4ed2af..ae6d4e29 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -212,6 +212,19 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam + def clone(self) -> "TransverseDeflectingCavity": + return TransverseDeflectingCavity( + length=self.length.clone(), + voltage=self.voltage.clone(), + phase=self.phase.clone(), + frequency=self.frequency.clone(), + misalignment=self.misalignment.clone(), + tilt=self.tilt.clone(), + num_steps=self.num_steps, + tracking_method=self.tracking_method, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index e7304870..ace18d2c 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -60,6 +60,13 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def is_skippable(self) -> bool: return True + def clone(self) -> "Undulator": + return Undulator( + length=self.length.clone(), + is_active=self.is_active, + name=self.name, + ) + def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for undulator properly, for now just return self return [self] diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index be5ba4e4..59b197ab 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -73,6 +73,11 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.angle != 0) + def clone(self) -> "VerticalCorrector": + return VerticalCorrector( + length=self.length.clone(), angle=self.angle.clone(), name=self.name + ) + def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ From e788b528f769df7ba24b3667801fb80c4cf41b95 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 14:38:56 +0100 Subject: [PATCH 04/31] Remove deepcopy from plotting method --- cheetah/accelerator/segment.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cheetah/accelerator/segment.py b/cheetah/accelerator/segment.py index b1a20b5d..2ecdea84 100644 --- a/cheetah/accelerator/segment.py +++ b/cheetah/accelerator/segment.py @@ -1,4 +1,3 @@ -from copy import deepcopy from functools import reduce from pathlib import Path from typing import Any, Optional, Union @@ -440,7 +439,7 @@ def plot_reference_particle_traces( case of present vector dimension but no index provided, the first one is used by default. """ - reference_segment = deepcopy(self) + reference_segment = self.clone() splits = reference_segment.split(resolution=torch.tensor(resolution)) split_lengths = [split.length for split in splits] From 27545364d9951f01b08f998bc1bfb0e8ab5286a6 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 14:40:35 +0100 Subject: [PATCH 05/31] Test clone implementations --- tests/test_clone.py | 50 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/test_clone.py diff --git a/tests/test_clone.py b/tests/test_clone.py new file mode 100644 index 00000000..29f51962 --- /dev/null +++ b/tests/test_clone.py @@ -0,0 +1,50 @@ +import pytest +import torch + +import cheetah + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Cavity, + cheetah.Dipole, + cheetah.Drift, + cheetah.HorizontalCorrector, + cheetah.Quadrupole, + cheetah.RBend, + cheetah.Solenoid, + cheetah.TransverseDeflectingCavity, + cheetah.Undulator, + cheetah.VerticalCorrector, + ], +) +def test_element_buffer_location(ElementClass): + """ + Test that the buffers of cloned elements do not share memory. + """ + element = ElementClass(length=torch.tensor(1.0)) + clone = element.clone() + + for buffer, buffer_clone in zip(element.buffers(), clone.buffers()): + assert torch.allclose(buffer, buffer_clone) + assert not buffer.data_ptr() == buffer_clone.data_ptr() + + +@pytest.mark.parametrize( + "BeamClass", + [ + cheetah.ParameterBeam, + cheetah.ParticleBeam, + ], +) +def test_beam_buffer_location(BeamClass): + """ + Test that the buffers of clones beams do not share memory. + """ + beam = BeamClass.from_parameters() + clone = beam.clone() + + for buffer, buffer_clone in zip(beam.buffers(), clone.buffers()): + assert torch.allclose(buffer, buffer_clone) + assert not buffer.data_ptr() == buffer_clone.data_ptr() From 404f4c934c7118ac734fdcd6c42233928c855cf4 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 15:31:53 +0100 Subject: [PATCH 06/31] Add test for tracking with nonleave tensors --- tests/test_tracking_nonleave_tensors.py | 63 +++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_tracking_nonleave_tensors.py diff --git a/tests/test_tracking_nonleave_tensors.py b/tests/test_tracking_nonleave_tensors.py new file mode 100644 index 00000000..dc5f9fb2 --- /dev/null +++ b/tests/test_tracking_nonleave_tensors.py @@ -0,0 +1,63 @@ +import pytest +import torch + +import cheetah + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Cavity, + cheetah.Dipole, + cheetah.Drift, + cheetah.HorizontalCorrector, + cheetah.Quadrupole, + cheetah.RBend, + cheetah.Solenoid, + cheetah.TransverseDeflectingCavity, + cheetah.Undulator, + cheetah.VerticalCorrector, + ], +) +def test_nonleave_tracking(ElementClass): + """ + Test that a beam with non-leave tensors as elements can be tracked through + elements with length parameter. + """ + beam = cheetah.ParticleBeam.from_parameters() + + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + ElementClass(length=torch.tensor(2.0)), + ] + ) + segment.track(beam) + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Aperture, + cheetah.BPM, + cheetah.Screen, + ], +) +def test_nonleave_lenghtless_elements(ElementClass): + """ + Test that a beam with non-leave tensors as elements can be tracked through + elements without length parameter. + + The split into lengthless elements is necessary since there is no common + constructor for all element classes. Some require a length, some cannot + handle a length argument. + """ + beam = cheetah.ParticleBeam.from_parameters() + + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + ElementClass(is_active=True), + ] + ) + segment.track(beam) From 94c52df56de973e1407e005ad1cc1a45c0bfaecb Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 15:33:28 +0100 Subject: [PATCH 07/31] Remove deepcopy from Screen --- cheetah/accelerator/screen.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 8a8836e9..5fd60301 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Literal, Optional, Union import matplotlib.pyplot as plt @@ -166,7 +165,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def track(self, incoming: Beam) -> Beam: if self.is_active: - copy_of_incoming = deepcopy(incoming) + copy_of_incoming = incoming.clone() if isinstance(incoming, ParameterBeam): copy_of_incoming._mu, _ = torch.broadcast_tensors( From 9aca0cdca29f427810d77854eca755334422d9db Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 28 Oct 2024 16:34:17 +0100 Subject: [PATCH 08/31] Test plotting with nonleave tensors --- tests/test_plotting.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c1b3bfe7..dd94b574 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -103,3 +103,26 @@ def test_reference_particle_plot_vectorized_2d(): # Run the plotting to see if it raises an exception segment.plot_overview(incoming=incoming, resolution=0.1, vector_idx=(0, 2)) + + +def test_plotting_with_nonleave_tensors(): + """ + Test that the plotting routines can handle elements with non-leave tensors. + """ + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + cheetah.BPM(is_active=True), + ] + ) + + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Prepopulate the segment + segment.track(incoming) + + # Test that plotting does not raise an exception + segment.plot_overview(incoming=incoming) + segment.plot_twiss(incoming=incoming) From 936a1d68646c481c01438e196792e6afd2c83105 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Thu, 21 Nov 2024 08:52:19 +0100 Subject: [PATCH 09/31] Fix formatting --- tests/test_plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 32d19287..6968d160 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -127,7 +127,7 @@ def test_plotting_with_nonleave_tensors(): segment.plot_overview(incoming=incoming) segment.plot_twiss(incoming=incoming) - + def test_plotting_with_gradients(): """ Test that plotting doesn't raise an exception for segments that contain tensors From 9fdab40bc43970a6f28e6a939ed247bc6851d36b Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Thu, 21 Nov 2024 08:53:37 +0100 Subject: [PATCH 10/31] Mark Beam.clone as abstract --- cheetah/particles/beam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index ae8bc262..553eda1a 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -391,6 +391,7 @@ def alpha_y(self) -> torch.Tensor: """Alpha function in y direction, dimensionless.""" return -self.sigma_ypy / self.emittance_y + @abstractmethod def clone(self) -> "Beam": """Return a cloned beam that does not share the underlying storage.""" raise NotImplementedError From b9e22d2f3dde7a5744d34fd00e66ef2ee7eb2a0e Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Fri, 22 Nov 2024 17:04:00 +0100 Subject: [PATCH 11/31] Write generic clone method for elements --- cheetah/accelerator/aperture.py | 9 --------- cheetah/accelerator/bpm.py | 3 --- cheetah/accelerator/cavity.py | 9 --------- cheetah/accelerator/custom_transfer_map.py | 7 ------- cheetah/accelerator/dipole.py | 18 ------------------ cheetah/accelerator/drift.py | 7 ------- cheetah/accelerator/element.py | 13 +++++++++++-- cheetah/accelerator/horizontal_corrector.py | 5 ----- cheetah/accelerator/marker.py | 3 --- cheetah/accelerator/quadrupole.py | 11 ----------- cheetah/accelerator/screen.py | 13 ------------- cheetah/accelerator/solenoid.py | 8 -------- cheetah/accelerator/space_charge_kick.py | 12 ------------ .../transverse_deflecting_cavity.py | 13 ------------- cheetah/accelerator/undulator.py | 7 ------- cheetah/accelerator/vertical_corrector.py | 5 ----- 16 files changed, 11 insertions(+), 132 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index af6c847e..cdf5d256 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -109,15 +109,6 @@ def track(self, incoming: Beam) -> Beam: else ParticleBeam.empty ) - def clone(self) -> "Aperture": - return Aperture( - x_max=self.x_max.clone(), - y_max=self.y_max.clone(), - shape=self.shape, - is_active=self.is_active, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for aperture properly, for now just return self return [self] diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index 73efdba8..14ec9130 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -47,9 +47,6 @@ def track(self, incoming: Beam) -> Beam: return incoming.clone() - def clone(self) -> "BPM": - return BPM(is_active=self.is_active, name=self.name) - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 8e4ec1b3..c7a89e05 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -336,15 +336,6 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: return R - def clone(self) -> "Cavity": - return Cavity( - length=self.length.clone(), - voltage=self.voltage.clone(), - phase=self.phase.clone(), - frequency=self.frequency.clone(), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 32ba51c4..5baf87d2 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -99,13 +99,6 @@ def __repr__(self): def defining_features(self) -> list[str]: return super().defining_features + ["transfer_map"] - def clone(self) -> "CustomTransferMap": - return CustomTransferMap( - transfer_map=self._transfer_map.clone(), - length=self.length.clone(), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 28cfef29..5e919fbf 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -439,24 +439,6 @@ def _transfer_map_exit(self) -> torch.Tensor: return tm - def clone(self) -> "Dipole": - return Dipole( - length=self.length.clone(), - angle=self.angle.clone(), - k1=self.k1.clone(), - e1=self.e1.clone(), - e2=self.e2.clone(), - tilt=self.tilt.clone(), - gap=self.gap.clone(), - gap_exit=self.gap_exit.clone(), - fringe_integral=self.fringe_integral.clone(), - fringe_integral_exit=self.fringe_integral_exit.clone(), - fringe_at=self.fringe_at, - fringe_type=self.fringe_type, - tracking_method=self.tracking_method, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for dipole properly, for now just returns the # element itself diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 627da035..eb5fb187 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -124,13 +124,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: def is_skippable(self) -> bool: return self.tracking_method == "cheetah" - def clone(self) -> "Drift": - return Drift( - length=self.length.clone(), - tracking_method=self.tracking_method, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index 1baa328f..b1ab1798 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from copy import deepcopy from typing import Optional import matplotlib.pyplot as plt @@ -113,10 +114,18 @@ def defining_features(self) -> list[str]: """ return [] - @abstractmethod def clone(self) -> "Element": """Create a copy of the element which does not share the underlying memory.""" - raise NotImplementedError + return self.__class__( + **{ + feature: ( + getattr(self, feature).clone() + if isinstance(getattr(self, feature), torch.Tensor) + else deepcopy(getattr(self, feature)) + ) + for feature in self.defining_features + } + ) @abstractmethod def split(self, resolution: torch.Tensor) -> list["Element"]: diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index dd423a18..a00c2cbe 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -70,11 +70,6 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.angle != 0) - def clone(self) -> "HorizontalCorrector": - return HorizontalCorrector( - length=self.length.clone(), angle=self.angle.clone(), name=self.name - ) - def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/marker.py b/cheetah/accelerator/marker.py index b0829a45..37daa593 100644 --- a/cheetah/accelerator/marker.py +++ b/cheetah/accelerator/marker.py @@ -34,9 +34,6 @@ def track(self, incoming: Beam) -> Beam: def is_skippable(self) -> bool: return True - def clone(self) -> "Marker": - return Marker(name=self.name) - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index f5f9c605..db6a559d 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -201,17 +201,6 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.k1 != 0) - def clone(self) -> "Quadrupole": - return Quadrupole( - length=self.length.clone(), - k1=self.k1.clone(), - misalignment=self.misalignment.clone(), - tilt=self.tilt.clone(), - num_steps=self.num_steps, - tracking_method=self.tracking_method, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index 5fd60301..8c43e020 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -293,19 +293,6 @@ def set_read_beam(self, value: Beam) -> None: self._read_beam = value self.cached_reading = None - def clone(self) -> "Screen": - return Screen( - resolution=self.resolution.clone(), - pixel_size=self.pixel_size.clone(), - binning=self.binning.clone(), - misalignment=self.misalignment.clone(), - method=self.method, - kde_bandwidth=self.kde_bandwidth.clone(), - is_blocking=self.is_blocking, - is_active=self.is_active, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: return [self] diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 5d4bc0fb..f8faf24c 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -112,14 +112,6 @@ def is_active(self) -> bool: def is_skippable(self) -> bool: return True - def clone(self) -> "Solenoid": - return Solenoid( - length=self.length.clone(), - k=self.k.clone(), - misalignment=self.misalignment.clone(), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for solenoid properly, for now just return self return [self] diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 80825abc..7bb7d7f3 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -633,18 +633,6 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") - def clone(self) -> "SpaceChargeKick": - return SpaceChargeKick( - effect_length=self.effect_length.clone(), - num_grid_points_x=self.grid_shape[0], - num_grid_points_y=self.grid_shape[1], - num_grid_points_tau=self.grid_shape[2], - grid_extend_x=self.grid_extend_x.clone(), - grid_extend_y=self.grid_extend_y.clone(), - grid_extend_tau=self.grid_extend_tau.clone(), - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for SpaceCharge properly, for now just returns the # element itself diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index 7c406e66..e4e9a3cb 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -213,19 +213,6 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: ) return outgoing_beam - def clone(self) -> "TransverseDeflectingCavity": - return TransverseDeflectingCavity( - length=self.length.clone(), - voltage=self.voltage.clone(), - phase=self.phase.clone(), - frequency=self.frequency.clone(), - misalignment=self.misalignment.clone(), - tilt=self.tilt.clone(), - num_steps=self.num_steps, - tracking_method=self.tracking_method, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for cavity properly, for now just returns the # element itself diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index ace18d2c..e7304870 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -60,13 +60,6 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: def is_skippable(self) -> bool: return True - def clone(self) -> "Undulator": - return Undulator( - length=self.length.clone(), - is_active=self.is_active, - name=self.name, - ) - def split(self, resolution: torch.Tensor) -> list[Element]: # TODO: Implement splitting for undulator properly, for now just return self return [self] diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 59b197ab..be5ba4e4 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -73,11 +73,6 @@ def is_skippable(self) -> bool: def is_active(self) -> bool: return torch.any(self.angle != 0) - def clone(self) -> "VerticalCorrector": - return VerticalCorrector( - length=self.length.clone(), angle=self.angle.clone(), name=self.name - ) - def split(self, resolution: torch.Tensor) -> list[Element]: num_splits = torch.ceil(torch.max(self.length) / resolution).int() return [ From 980f42b063f2bdac9ed300d6bf0e8410252012bb Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Fri, 22 Nov 2024 17:04:26 +0100 Subject: [PATCH 12/31] Add missing argument to RBend constructor --- cheetah/accelerator/rbend.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index b50ef08e..04ba7633 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -23,6 +23,9 @@ class RBend(Dipole): :param fringe_integral_exit: (only set if different from `fint`) Fringe field integral of the exit face. :param gap: The magnet gap [m], NOTE in MAD and ELEGANT: HGAP = gap/2 + :param gap_exit: The magnet gap at the exit in meters. Note that in MAD and + ELEGANT: HGAP = gap/2. Only set if different from `gap`. Only used with + `"bmadx"` tracking method. :param name: Unique identifier of the element. """ @@ -37,6 +40,7 @@ def __init__( fringe_integral: Optional[Union[torch.Tensor, nn.Parameter]] = None, fringe_integral_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, + gap_exit: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, dtype=torch.float32, @@ -51,6 +55,7 @@ def __init__( fringe_integral=fringe_integral, fringe_integral_exit=fringe_integral_exit, gap=gap, + gap_exit=gap_exit, name=name, device=device, dtype=dtype, From 4572bb7bfa0aff8dd4811673c11b8578b20abc8f Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 25 Nov 2024 11:43:02 +0100 Subject: [PATCH 13/31] Add missing is_active feature for BPM --- cheetah/accelerator/bpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/bpm.py b/cheetah/accelerator/bpm.py index 14ec9130..ba5d796d 100644 --- a/cheetah/accelerator/bpm.py +++ b/cheetah/accelerator/bpm.py @@ -61,7 +61,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No @property def defining_features(self) -> list[str]: - return super().defining_features + return super().defining_features + ["is_active"] def __repr__(self) -> str: return f"{self.__class__.__name__}(name={repr(self.name)})" From cfd0ef9a5e674cbfaa316b078e1257b27bd2530b Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 25 Nov 2024 11:47:43 +0100 Subject: [PATCH 14/31] Add clone override for CustomTransferMap --- cheetah/accelerator/custom_transfer_map.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 5baf87d2..5c782094 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -96,12 +96,20 @@ def __repr__(self): + f"name={repr(self.name)})" ) + @property def defining_features(self) -> list[str]: return super().defining_features + ["transfer_map"] def split(self, resolution: torch.Tensor) -> list[Element]: return [self] + def clone(self) -> "CustomTransferMap": + return CustomTransferMap( + transfer_map=self._transfer_map.clone(), + length=self.length.clone(), + name=self.name, + ) + def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None: plot_s = s[vector_idx] if s.dim() > 0 else s plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length From bc914daec64a02a1c58dbb3bb152c5a9d92f9c3e Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 25 Nov 2024 11:49:36 +0100 Subject: [PATCH 15/31] Apply suggestions from code review --- cheetah/particles/beam.py | 2 +- cheetah/particles/parameter_beam.py | 2 +- cheetah/particles/particle_beam.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 553eda1a..bcea43c3 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -393,7 +393,7 @@ def alpha_y(self) -> torch.Tensor: @abstractmethod def clone(self) -> "Beam": - """Return a cloned beam that does not share the underlying storage.""" + """Return a cloned beam that does not share the underlying memory.""" raise NotImplementedError def __repr__(self) -> str: diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index 49a1f965..0ddd3809 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -430,8 +430,8 @@ def clone(self) -> "ParameterBeam": return ParameterBeam( mu=self._mu.clone(), cov=self._cov.clone(), - total_charge=self.total_charge.clone(), energy=self.energy.clone(), + total_charge=self.total_charge.clone(), ) def __repr__(self) -> str: diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 1bccc931..31d58e45 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -899,8 +899,8 @@ def momenta(self) -> torch.Tensor: def clone(self) -> "ParticleBeam": return ParticleBeam( particles=self.particles.clone(), - particle_charges=self.particle_charges.clone(), energy=self.energy.clone(), + particle_charges=self.particle_charges.clone(), ) def __repr__(self) -> str: From bdeb9616f9665ce6c243dd23c9fb9819a129d95b Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Mon, 25 Nov 2024 11:59:07 +0100 Subject: [PATCH 16/31] Add name as a defining feature --- cheetah/accelerator/element.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/element.py b/cheetah/accelerator/element.py index b1ab1798..2ad7c09e 100644 --- a/cheetah/accelerator/element.py +++ b/cheetah/accelerator/element.py @@ -112,7 +112,7 @@ def defining_features(self) -> list[str]: NOTE: When overriding this property, make sure to call the super method and extend the list it returns. """ - return [] + return ["name"] def clone(self) -> "Element": """Create a copy of the element which does not share the underlying memory.""" From 94a204be8415e43fc3522c470b04f033d1ac8f44 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 15:16:31 +0100 Subject: [PATCH 17/31] Fix issue where `name` was considered twice with LatticeJSON saving --- cheetah/latticejson.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cheetah/latticejson.py b/cheetah/latticejson.py index ffd0af44..4f83edd3 100644 --- a/cheetah/latticejson.py +++ b/cheetah/latticejson.py @@ -33,6 +33,7 @@ def convert_element(element: "cheetah.Element"): params = { feauture: feature2nontorch(getattr(element, feauture)) for feauture in element.defining_features + if feauture != "name" } return element.name, element.__class__.__name__, params From ba6ff8407c2ee366e8cf2a3ee990fc4851e57b8b Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 15:20:20 +0100 Subject: [PATCH 18/31] Make sure `is_active` is part of `defining_features` when it can be set manually --- cheetah/accelerator/aperture.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index 6999efdd..cfe5c5b1 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -128,12 +128,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No @property def defining_features(self) -> list[str]: - return super().defining_features + [ - "x_max", - "y_max", - "shape", - "is_active", - ] + return super().defining_features + ["x_max", "y_max", "shape", "is_active"] def __repr__(self) -> str: return ( From 9d417fbc98f68468bb584ba6f641020bc35fc401 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 16:05:54 +0100 Subject: [PATCH 19/31] Fix `RBend` vs `Dipole` `defining_features` issue --- cheetah/accelerator/dipole.py | 78 ++++++++----- cheetah/accelerator/horizontal_corrector.py | 2 +- cheetah/accelerator/quadrupole.py | 4 +- cheetah/accelerator/rbend.py | 121 +++++++++++++++++--- cheetah/accelerator/solenoid.py | 2 +- cheetah/accelerator/vertical_corrector.py | 4 +- cheetah/converters/bmad.py | 4 +- cheetah/converters/elegant.py | 12 +- cheetah/converters/nxtables.py | 2 +- cheetah/converters/ocelot.py | 12 +- tests/test_bmad_conversion.py | 8 +- tests/test_compare_ocelot.py | 4 +- tests/test_dipole.py | 4 +- tests/test_elegant_conversion.py | 10 +- 14 files changed, 185 insertions(+), 82 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 6f253be6..9ad343fb 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -49,8 +49,8 @@ def __init__( length: torch.Tensor, angle: Optional[torch.Tensor] = None, k1: Optional[torch.Tensor] = None, - e1: Optional[torch.Tensor] = None, - e2: Optional[torch.Tensor] = None, + dipole_e1: Optional[torch.Tensor] = None, + dipole_e2: Optional[torch.Tensor] = None, tilt: Optional[torch.Tensor] = None, gap: Optional[torch.Tensor] = None, gap_exit: Optional[torch.Tensor] = None, @@ -68,8 +68,8 @@ def __init__( length, angle, k1, - e1, - e2, + dipole_e1, + dipole_e2, tilt, gap, gap_exit, @@ -88,7 +88,7 @@ def __init__( ( torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( @@ -96,23 +96,23 @@ def __init__( ( torch.as_tensor(k1, **factory_kwargs) if k1 is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( - "e1", + "_e1", ( - torch.as_tensor(e1, **factory_kwargs) - if e1 is not None - else torch.zeros_like(self.length) + torch.as_tensor(dipole_e1, **factory_kwargs) + if dipole_e1 is not None + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( - "e2", + "_e2", ( - torch.as_tensor(e2, **factory_kwargs) - if e2 is not None - else torch.zeros_like(self.length) + torch.as_tensor(dipole_e2, **factory_kwargs) + if dipole_e2 is not None + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( @@ -120,7 +120,7 @@ def __init__( ( torch.as_tensor(fringe_integral, **factory_kwargs) if fringe_integral is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( @@ -128,7 +128,7 @@ def __init__( ( self.fringe_integral if fringe_integral_exit is None - else torch.as_tensor(fringe_integral_exit, **factory_kwargs) + else torch.tensor(fringe_integral_exit, **factory_kwargs) ), ) self.register_buffer( @@ -136,7 +136,7 @@ def __init__( ( torch.as_tensor(gap, **factory_kwargs) if gap is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( @@ -152,7 +152,7 @@ def __init__( ( torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.fringe_at = fringe_at @@ -163,6 +163,22 @@ def __init__( def hx(self) -> torch.Tensor: return torch.where(self.length == 0.0, 0.0, self.angle / self.length) + @property + def dipole_e1(self) -> torch.Tensor: + return self._e1 + + @dipole_e1.setter + def dipole_e1(self, value: torch.Tensor): + self._e1 = value + + @property + def dipole_e2(self) -> torch.Tensor: + return self._e2 + + @dipole_e2.setter + def dipole_e2(self, value: torch.Tensor): + self._e2 = value + @property def is_skippable(self) -> bool: return self.tracking_method == "cheetah" @@ -378,7 +394,7 @@ def _bmadx_fringe_linear( :return: px, py final Bmad cannonical coordinates. """ g = self.angle / self.length - e = self.e1 * (location == "entrance") + self.e2 * (location == "exit") + e = self._e1 * (location == "entrance") + self._e2 * (location == "exit") f_int = self.fringe_integral * ( location == "entrance" ) + self.fringe_integral_exit * (location == "exit") @@ -431,18 +447,18 @@ def _transfer_map_enter(self) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - sec_e = 1.0 / torch.cos(self.e1) + sec_e = 1.0 / torch.cos(self._e1) phi = ( self.fringe_integral * self.hx * self.gap * sec_e - * (1 + torch.sin(self.e1) ** 2) + * (1 + torch.sin(self._e1) ** 2) ) tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1) - tm[..., 1, 0] = self.hx * torch.tan(self.e1) - tm[..., 3, 2] = -self.hx * torch.tan(self.e1 - phi) + tm[..., 1, 0] = self.hx * torch.tan(self._e1) + tm[..., 3, 2] = -self.hx * torch.tan(self._e1 - phi) return tm @@ -451,18 +467,18 @@ def _transfer_map_exit(self) -> torch.Tensor: device = self.length.device dtype = self.length.dtype - sec_e = 1.0 / torch.cos(self.e2) + sec_e = 1.0 / torch.cos(self._e2) phi = ( self.fringe_integral_exit * self.hx * self.gap * sec_e - * (1 + torch.sin(self.e2) ** 2) + * (1 + torch.sin(self._e2) ** 2) ) tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1) - tm[..., 1, 0] = self.hx * torch.tan(self.e2) - tm[..., 3, 2] = -self.hx * torch.tan(self.e2 - phi) + tm[..., 1, 0] = self.hx * torch.tan(self._e2) + tm[..., 3, 2] = -self.hx * torch.tan(self._e2 - phi) return tm @@ -476,8 +492,8 @@ def __repr__(self): f"{self.__class__.__name__}(length={repr(self.length)}, " + f"angle={repr(self.angle)}, " + f"k1={repr(self.k1)}, " - + f"e1={repr(self.e1)}," - + f"e2={repr(self.e2)}," + + f"dipole_e1={repr(self.dipole_e1)}," + + f"dipole_e2={repr(self.dipole_e2)}," + f"tilt={repr(self.tilt)}," + f"gap={repr(self.gap)}," + f"gap_exit={repr(self.gap_exit)}," @@ -495,8 +511,8 @@ def defining_features(self) -> list[str]: "length", "angle", "k1", - "e1", - "e2", + "dipole_e1", + "dipole_e2", "tilt", "gap", "gap_exit", diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 36456aa0..670ca295 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -44,7 +44,7 @@ def __init__( ( torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index cbc7e00c..775b3e41 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -55,7 +55,7 @@ def __init__( ( torch.as_tensor(k1, **factory_kwargs) if k1 is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( @@ -71,7 +71,7 @@ def __init__( ( torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.num_steps = num_steps diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index bb332b45..2862688c 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -1,9 +1,9 @@ -from typing import Optional +from typing import Literal, Optional import torch from cheetah.accelerator.dipole import Dipole -from cheetah.utils import UniqueNameGenerator +from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -17,14 +17,21 @@ class RBend(Dipole): :param k1: Focussing strength in 1/m^-2. :param e1: The angle of inclination of the entrance face [rad]. :param e2: The angle of inclination of the exit face [rad]. - :param tilt: Tilt of the magnet in x-y plane [rad]. - :param fringe_integral: Fringe field integral (of the enterance face). - :param fringe_integral_exit: (only set if different from `fint`) Fringe field - integral of the exit face. - :param gap: The magnet gap [m], NOTE in MAD and ELEGANT: HGAP = gap/2 + :param gap: The magnet gap in meters. Note that in MAD and ELEGANT: HGAP = gap/2. :param gap_exit: The magnet gap at the exit in meters. Note that in MAD and ELEGANT: HGAP = gap/2. Only set if different from `gap`. Only used with `"bmadx"` tracking method. + :param fringe_integral: Fringe field integral (of the enterance face). + :param fringe_integral_exit: Fringe field integral of the exit face. Only set if + different from `fringe_integral`. Only used with `"bmadx"` tracking method. + :param fringe_at: Where to apply the fringe fields for `"bmadx"` tracking. The + available options are: + - "neither": Do not apply fringe fields. + - "entrance": Apply fringe fields at the entrance end. + - "exit": Apply fringe fields at the exit end. + - "both": Apply fringe fields at both ends. + :param fringe_type: Type of fringe field for `"bmadx"` tracking. Currently only + supports `"linear_edge"`. :param name: Unique identifier of the element. """ @@ -33,33 +40,111 @@ def __init__( length: Optional[torch.Tensor], angle: Optional[torch.Tensor] = None, k1: Optional[torch.Tensor] = None, - e1: Optional[torch.Tensor] = None, - e2: Optional[torch.Tensor] = None, + rbend_e1: Optional[torch.Tensor] = None, + rbend_e2: Optional[torch.Tensor] = None, tilt: Optional[torch.Tensor] = None, - fringe_integral: Optional[torch.Tensor] = None, - fringe_integral_exit: Optional[torch.Tensor] = None, gap: Optional[torch.Tensor] = None, gap_exit: Optional[torch.Tensor] = None, + fringe_integral: Optional[torch.Tensor] = None, + fringe_integral_exit: Optional[torch.Tensor] = None, + fringe_at: Literal["neither", "entrance", "exit", "both"] = "both", + fringe_type: Literal["linear_edge"] = "linear_edge", + tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, dtype=None, ): + # Set default values needed for conversion from RBend to Dipole + device, dtype = verify_device_and_dtype( + [ + length, + angle, + k1, + rbend_e1, + rbend_e2, + tilt, + gap, + gap_exit, + fringe_integral, + fringe_integral_exit, + ], + device, + dtype, + ) + factory_kwargs = {"device": device, "dtype": dtype} + + angle = ( + torch.as_tensor(angle, **factory_kwargs) + if angle is not None + else torch.tensor(0.0, **factory_kwargs) + ) + rbend_e1 = ( + torch.as_tensor(rbend_e1, **factory_kwargs) + if rbend_e1 is not None + else torch.tensor(0.0, **factory_kwargs) + ) + rbend_e2 = ( + torch.as_tensor(rbend_e2, **factory_kwargs) + if rbend_e2 is not None + else torch.tensor(0.0, **factory_kwargs) + ) + super().__init__( length=length, angle=angle, k1=k1, - e1=e1, - e2=e2, + dipole_e1=rbend_e1 + angle / 2, + dipole_e2=rbend_e2 + angle / 2, tilt=tilt, - fringe_integral=fringe_integral, - fringe_integral_exit=fringe_integral_exit, gap=gap, gap_exit=gap_exit, + fringe_integral=fringe_integral, + fringe_integral_exit=fringe_integral_exit, + fringe_at=fringe_at, + fringe_type=fringe_type, + tracking_method=tracking_method, name=name, device=device, dtype=dtype, ) - # Rectangular bend - self.e1 = self.e1 + self.angle / 2 - self.e2 = self.e2 + self.angle / 2 + @property + def rbend_e1(self): + return self.dipole_e1 - self.angle / 2 + + @rbend_e1.setter + def rbend_e1(self, value): + self.dipole_e1 = value + self.angle / 2 + + @property + def rbend_e2(self): + return self.dipole_e2 - self.angle / 2 + + @rbend_e2.setter + def rbend_e2(self, value): + self.dipole_e2 = value + self.angle / 2 + + def __repr__(self): + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"angle={repr(self.angle)}, " + + f"k1={repr(self.k1)}, " + + f"rbend_e1={repr(self.rbend_e1)}," + + f"rbend_e2={repr(self.rbend_e2)}," + + f"tilt={repr(self.tilt)}," + + f"gap={repr(self.gap)}," + + f"gap_exit={repr(self.gap_exit)}," + + f"fringe_integral={repr(self.fringe_integral)}," + + f"fringe_integral_exit={repr(self.fringe_integral_exit)}," + + f"fringe_at={repr(self.fringe_at)}," + + f"fringe_type={repr(self.fringe_type)}," + + f"tracking_method={repr(self.tracking_method)}, " + + f"name={repr(self.name)})" + ) + + @property + def defining_features(self): + dipole_features = super().defining_features + dipole_features.remove("dipole_e1") + dipole_features.remove("dipole_e2") + return dipole_features + ["rbend_e1", "rbend_e2"] diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index d59ccd16..97c0e389 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -53,7 +53,7 @@ def __init__( ( torch.as_tensor(k, **factory_kwargs) if k is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.register_buffer( diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 78e78fb3..bb5023c4 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -47,7 +47,7 @@ def __init__( ( torch.as_tensor(angle, **factory_kwargs) if angle is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) @@ -86,7 +86,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]: dtype=self.length.dtype, device=self.length.device, ) - for i in range(num_splits) + for _ in range(num_splits) ] def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None: diff --git a/cheetah/converters/bmad.py b/cheetah/converters/bmad.py index 6c139570..a6c34ef6 100644 --- a/cheetah/converters/bmad.py +++ b/cheetah/converters/bmad.py @@ -146,8 +146,8 @@ def convert_element( length=torch.tensor(bmad_parsed["l"]), gap=torch.tensor(2 * bmad_parsed.get("hgap", 0.0)), angle=torch.tensor(bmad_parsed.get("angle", 0.0)), - e1=torch.tensor(bmad_parsed["e1"]), - e2=torch.tensor(bmad_parsed.get("e2", 0.0)), + dipole_e1=torch.tensor(bmad_parsed["e1"]), + dipole_e2=torch.tensor(bmad_parsed.get("e2", 0.0)), tilt=torch.tensor(bmad_parsed.get("ref_tilt", 0.0)), fringe_integral=torch.tensor(bmad_parsed.get("fint", 0.0)), fringe_integral_exit=( diff --git a/cheetah/converters/elegant.py b/cheetah/converters/elegant.py index 4927593f..a1073fba 100644 --- a/cheetah/converters/elegant.py +++ b/cheetah/converters/elegant.py @@ -362,8 +362,8 @@ def convert_element( length=torch.tensor(parsed["l"]), angle=torch.tensor(parsed.get("angle", 0.0)), k1=torch.tensor(parsed.get("k1", 0.0)), - e1=torch.tensor(parsed.get("e1", 0.0)), - e2=torch.tensor(parsed.get("e2", 0.0)), + dipole_e1=torch.tensor(parsed.get("e1", 0.0)), + dipole_e2=torch.tensor(parsed.get("e2", 0.0)), tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, @@ -377,8 +377,8 @@ def convert_element( return cheetah.RBend( length=torch.tensor(parsed["l"]), angle=torch.tensor(parsed.get("angle", 0.0)), - e1=torch.tensor(parsed.get("e1", 0.0)), - e2=torch.tensor(parsed.get("e2", 0.0)), + rbend_e1=torch.tensor(parsed.get("e1", 0.0)), + rbend_e2=torch.tensor(parsed.get("e2", 0.0)), tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, @@ -413,8 +413,8 @@ def convert_element( length=torch.tensor(parsed["l"]), angle=torch.tensor(parsed.get("angle", 0.0)), k1=torch.tensor(parsed.get("k1", 0.0)), - e1=torch.tensor(parsed.get("e1", 0.0)), - e2=torch.tensor(parsed.get("e2", 0.0)), + dipole_e1=torch.tensor(parsed.get("e1", 0.0)), + dipole_e2=torch.tensor(parsed.get("e2", 0.0)), tilt=torch.tensor(parsed.get("tilt", 0.0)), name=name, device=device, diff --git a/cheetah/converters/nxtables.py b/cheetah/converters/nxtables.py index 6531e170..f7ae774f 100644 --- a/cheetah/converters/nxtables.py +++ b/cheetah/converters/nxtables.py @@ -146,7 +146,7 @@ def translate_element(row: list[str], header: list[str]) -> Optional[Dict]: name=name, length=torch.tensor(0.43852543421396856), angle=torch.tensor(0.8203047484373349), - e2=torch.tensor(-0.7504915783575616), + dipole_e2=torch.tensor(-0.7504915783575616), ) elif class_name == "MQZM": element = cheetah.Quadrupole(name=name, length=torch.tensor(0.122)) diff --git a/cheetah/converters/ocelot.py b/cheetah/converters/ocelot.py index 51fc444f..d1577563 100644 --- a/cheetah/converters/ocelot.py +++ b/cheetah/converters/ocelot.py @@ -72,8 +72,8 @@ def convert_element_to_cheetah( return cheetah.Dipole( length=torch.tensor(element.l, dtype=torch.float32), angle=torch.tensor(element.angle, dtype=torch.float32), - e1=torch.tensor(element.e1, dtype=torch.float32), - e2=torch.tensor(element.e2, dtype=torch.float32), + dipole_e1=torch.tensor(element.e1, dtype=torch.float32), + dipole_e2=torch.tensor(element.e2, dtype=torch.float32), tilt=torch.tensor(element.tilt, dtype=torch.float32), fringe_integral=torch.tensor(element.fint, dtype=torch.float32), fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), @@ -86,8 +86,8 @@ def convert_element_to_cheetah( return cheetah.Dipole( length=torch.tensor(element.l, dtype=torch.float32), angle=torch.tensor(element.angle, dtype=torch.float32), - e1=torch.tensor(element.e1, dtype=torch.float32), - e2=torch.tensor(element.e2, dtype=torch.float32), + dipole_e1=torch.tensor(element.e1, dtype=torch.float32), + dipole_e2=torch.tensor(element.e2, dtype=torch.float32), tilt=torch.tensor(element.tilt, dtype=torch.float32), fringe_integral=torch.tensor(element.fint, dtype=torch.float32), fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), @@ -100,8 +100,8 @@ def convert_element_to_cheetah( return cheetah.RBend( length=torch.tensor(element.l, dtype=torch.float32), angle=torch.tensor(element.angle, dtype=torch.float32), - e1=torch.tensor(element.e1, dtype=torch.float32) - element.angle / 2, - e2=torch.tensor(element.e2, dtype=torch.float32) - element.angle / 2, + rbend_e1=torch.tensor(element.e1, dtype=torch.float32) - element.angle / 2, + rbend_e2=torch.tensor(element.e2, dtype=torch.float32) - element.angle / 2, tilt=torch.tensor(element.tilt, dtype=torch.float32), fringe_integral=torch.tensor(element.fint, dtype=torch.float32), fringe_integral_exit=torch.tensor(element.fintx, dtype=torch.float32), diff --git a/tests/test_bmad_conversion.py b/tests/test_bmad_conversion.py index 5a6477ad..d3d05761 100644 --- a/tests/test_bmad_conversion.py +++ b/tests/test_bmad_conversion.py @@ -15,7 +15,7 @@ def test_bmad_tutorial(): [ cheetah.Drift(length=torch.tensor([0.5]), name="d"), cheetah.Dipole( - length=torch.tensor([0.5]), e1=torch.tensor([0.1]), name="b" + length=torch.tensor([0.5]), dipole_e1=torch.tensor([0.1]), name="b" ), # TODO: What are g and dg? cheetah.Quadrupole( length=torch.tensor([0.6]), k1=torch.tensor([0.23]), name="q" @@ -31,7 +31,7 @@ def test_bmad_tutorial(): ] assert converted.d.length == correct.d.length assert converted.b.length == correct.b.length - assert converted.b.e1 == correct.b.e1 + assert converted.b.dipole_e1 == correct.b.dipole_e1 assert converted.q.length == correct.q.length assert converted.q.k1 == correct.q.k1 @@ -64,7 +64,7 @@ def test_device_passing(device: torch.device): # Check that the properties of the loaded elements are on the correct device assert converted.d.length.device.type == device.type assert converted.b.length.device.type == device.type - assert converted.b.e1.device.type == device.type + assert converted.b.dipole_e1.device.type == device.type assert converted.q.length.device.type == device.type assert converted.q.k1.device.type == device.type @@ -80,6 +80,6 @@ def test_dtype_passing(dtype: torch.dtype): # Check that the properties of the loaded elements are of the correct dtype assert converted.d.length.dtype == dtype assert converted.b.length.dtype == dtype - assert converted.b.e1.dtype == dtype + assert converted.b.dipole_e1.dtype == dtype assert converted.q.length.dtype == dtype assert converted.q.k1.dtype == dtype diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 1478454e..4a336a69 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -115,8 +115,8 @@ def test_dipole_with_fringe_field_and_tilt(): fringe_integral=torch.tensor(0.1), gap=torch.tensor(0.2), tilt=torch.tensor(tilt_angle), - e1=torch.tensor(bend_angle / 2), - e2=torch.tensor(bend_angle / 2), + dipole_e1=torch.tensor(bend_angle / 2), + dipole_e2=torch.tensor(bend_angle / 2), ) outgoing_beam = cheetah_dipole(incoming_beam) diff --git a/tests/test_dipole.py b/tests/test_dipole.py index 08d016e4..fcd1c6d4 100644 --- a/tests/test_dipole.py +++ b/tests/test_dipole.py @@ -127,8 +127,8 @@ def test_dipole_bmadx_tracking(dtype): dipole_cheetah_bmadx = Dipole( length=torch.tensor([0.5]), angle=angle, - e1=e1, - e2=e2, + dipole_e1=e1, + dipole_e2=e2, tilt=torch.tensor([0.1], dtype=dtype), fringe_integral=torch.tensor([0.5]), fringe_integral_exit=torch.tensor([0.5]), diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 01fdfbaf..00a01fbe 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -19,7 +19,9 @@ def test_fodo(): ), cheetah.Drift(name="d1", length=torch.tensor(1.0)), cheetah.Marker(name="m1"), - cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)), + cheetah.Dipole( + name="s1", length=torch.tensor(0.3), dipole_e1=torch.tensor(0.25) + ), cheetah.Drift(name="d1", length=torch.tensor(1.0)), cheetah.Quadrupole( name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3.0) @@ -41,7 +43,7 @@ def test_fodo(): assert torch.isclose(converted.d1[i].length, correct_lattice.d1[i].length) assert torch.isclose(converted.d2.length, correct_lattice.d2.length) assert torch.isclose(converted.s1.length, correct_lattice.s1.length) - assert torch.isclose(converted.s1.e1, correct_lattice.s1.e1) + assert torch.isclose(converted.s1.dipole_e1, correct_lattice.s1.dipole_e1) def test_cavity_import(): @@ -110,7 +112,7 @@ def test_device_passing(device: torch.device): assert [d.length.device.type for d in converted.d1] == [device.type, device.type] assert converted.d2.length.device.type == device.type assert converted.s1.length.device.type == device.type - assert converted.s1.e1.device.type == device.type + assert converted.s1.dipole_e1.device.type == device.type @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @@ -129,4 +131,4 @@ def test_dtype_passing(dtype: torch.dtype): assert [d.length.dtype for d in converted.d1] == [dtype, dtype] assert converted.d2.length.dtype == dtype assert converted.s1.length.dtype == dtype - assert converted.s1.e1.dtype == dtype + assert converted.s1.dipole_e1.dtype == dtype From 456cce02fb7717025e8976f6b23c93045e08c82c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 16:32:54 +0100 Subject: [PATCH 20/31] Add test that finds `CustomTransferMap` cloning issue --- tests/test_lattice_json.py | 50 ++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/tests/test_lattice_json.py b/tests/test_lattice_json.py index b03eec90..65cf3939 100644 --- a/tests/test_lattice_json.py +++ b/tests/test_lattice_json.py @@ -1,14 +1,16 @@ -from cheetah import Segment +import torch + +import cheetah from .resources import ARESlatticeStage3v1_9 as ares -def test_save_and_reload(tmp_path): +def test_save_and_reload_ares_example(tmp_path): """ Test that saving Cheetah `Segment` to LatticeJSON works and that it can be reloaded - correctly. + correctly at the example of the full ARES lattice. """ - original_segment = Segment.from_ocelot(ares.cell, name="ARES_Segment") + original_segment = cheetah.Segment.from_ocelot(ares.cell, name="ARES_Segment") original_segment.to_lattice_json( str(tmp_path / "ares_lattice.json"), @@ -16,7 +18,9 @@ def test_save_and_reload(tmp_path): info="Save and reload test for Cheetah using the ARES lattice", ) - reloaded_segment = Segment.from_lattice_json(str(tmp_path / "ares_lattice.json")) + reloaded_segment = cheetah.Segment.from_lattice_json( + str(tmp_path / "ares_lattice.json") + ) assert original_segment.name == reloaded_segment.name assert len(original_segment.elements) == len(reloaded_segment.elements) @@ -28,3 +32,39 @@ def test_save_and_reload(tmp_path): ): assert original_element.name == reloaded_element.name assert original_element.__class__ == reloaded_element.__class__ + + +def test_save_and_reload_custom_transfer_map(tmp_path): + """ + Test that saving and reloading a `CustomTransferMap` works. `CustomTransferMap` + never appears in the ARES lattice and must therefore be tested separately. + """ + custom_transfer_map_element = cheetah.CustomTransferMap( + transfer_map=torch.eye(7, 7), + length=torch.tensor(1.0), + name="my_custom_transfer_map_element", + ) + segment = cheetah.Segment([custom_transfer_map_element], name="test_segment") + + segment.to_lattice_json( + str(tmp_path / "custom_transfer_map_lattice.json"), + title="Custom Transfer Map LatticeJSON", + info="Save and reload test for Cheetah using a custom transfer map", + ) + + reloaded_segment = cheetah.Segment.from_lattice_json( + str(tmp_path / "custom_transfer_map_lattice.json") + ) + + # I really only care that the transfer map element is recovered correctly, the + # segment was tested in a different test. + reloaded_custom_transfer_map_element = reloaded_segment.elements[0] + + assert torch.allclose( + custom_transfer_map_element._transfer_map, + reloaded_custom_transfer_map_element._transfer_map, + ) + assert torch.allclose( + custom_transfer_map_element.length, reloaded_custom_transfer_map_element.length + ) + assert custom_transfer_map_element.name == reloaded_custom_transfer_map_element.name From 21023b9617e07ee207cda95da01b8862bb57237f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 16:36:10 +0100 Subject: [PATCH 21/31] Fix `deepcopy` test failure resulting from merging `master` --- cheetah/accelerator/screen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index fb9e1676..e41453d4 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -204,7 +204,7 @@ def track(self, incoming: Beam) -> Beam: ), ) else: - return deepcopy(incoming) + return incoming.clone() @property def reading(self) -> torch.Tensor: From cd5652a59ebc2c3a4ff26750352508bc497069e1 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 16:45:39 +0100 Subject: [PATCH 22/31] Fix issue with `CustomTransferMap` saving and cloning --- cheetah/accelerator/custom_transfer_map.py | 29 ++++++++++------------ cheetah/converters/elegant.py | 2 +- tests/test_elegant_conversion.py | 2 +- tests/test_lattice_json.py | 6 ++--- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 36a8ec05..896c60e1 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -18,28 +18,31 @@ class CustomTransferMap(Element): def __init__( self, - transfer_map: torch.Tensor, + predefined_transfer_map: torch.Tensor, length: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, dtype=None, ) -> None: - device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype) + device, dtype = verify_device_and_dtype( + [predefined_transfer_map, length], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) - assert isinstance(transfer_map, torch.Tensor) - assert transfer_map.shape[-2:] == (7, 7) + assert isinstance(predefined_transfer_map, torch.Tensor) + assert predefined_transfer_map.shape[-2:] == (7, 7) self.register_buffer( - "_transfer_map", torch.as_tensor(transfer_map, **factory_kwargs) + "predefined_transfer_map", + torch.as_tensor(predefined_transfer_map, **factory_kwargs), ) self.register_buffer( "length", ( torch.as_tensor(length, **factory_kwargs) if length is not None - else torch.zeros(transfer_map.shape[:-2], **factory_kwargs) + else torch.zeros(predefined_transfer_map.shape[:-2], **factory_kwargs) ), ) @@ -83,7 +86,7 @@ def from_merging_elements( ) def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - return self._transfer_map + return self.predefined_transfer_map @property def is_skippable(self) -> bool: @@ -91,25 +94,19 @@ def is_skippable(self) -> bool: def __repr__(self): return ( - f"{self.__class__.__name__}(transfer_map={repr(self._transfer_map)}, " + f"{self.__class__.__name__}(" + + f"predefined_transfer_map={repr(self.predefined_transfer_map)}, " + f"length={repr(self.length)}, " + f"name={repr(self.name)})" ) @property def defining_features(self) -> list[str]: - return super().defining_features + ["transfer_map"] + return super().defining_features + ["length", "predefined_transfer_map"] def split(self, resolution: torch.Tensor) -> list[Element]: return [self] - def clone(self) -> "CustomTransferMap": - return CustomTransferMap( - transfer_map=self._transfer_map.clone(), - length=self.length.clone(), - name=self.name, - ) - def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> None: plot_s = s[vector_idx] if s.dim() > 0 else s plot_length = self.length[vector_idx] if self.length.dim() > 0 else self.length diff --git a/cheetah/converters/elegant.py b/cheetah/converters/elegant.py index a1073fba..b0c462e6 100644 --- a/cheetah/converters/elegant.py +++ b/cheetah/converters/elegant.py @@ -250,7 +250,7 @@ def convert_element( return cheetah.CustomTransferMap( length=torch.tensor(parsed["l"]), - transfer_map=R, + predefined_transfer_map=R, name=name, device=device, dtype=dtype, diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 00a01fbe..1f482ed3 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -76,7 +76,7 @@ def test_custom_transfer_map_import(): ] ) - assert torch.allclose(converted.c1e._transfer_map, correct_transfer_map) + assert torch.allclose(converted.c1e.predefined_transfer_map, correct_transfer_map) @pytest.mark.parametrize( diff --git a/tests/test_lattice_json.py b/tests/test_lattice_json.py index 65cf3939..735e1830 100644 --- a/tests/test_lattice_json.py +++ b/tests/test_lattice_json.py @@ -40,7 +40,7 @@ def test_save_and_reload_custom_transfer_map(tmp_path): never appears in the ARES lattice and must therefore be tested separately. """ custom_transfer_map_element = cheetah.CustomTransferMap( - transfer_map=torch.eye(7, 7), + predefined_transfer_map=torch.eye(7, 7), length=torch.tensor(1.0), name="my_custom_transfer_map_element", ) @@ -61,8 +61,8 @@ def test_save_and_reload_custom_transfer_map(tmp_path): reloaded_custom_transfer_map_element = reloaded_segment.elements[0] assert torch.allclose( - custom_transfer_map_element._transfer_map, - reloaded_custom_transfer_map_element._transfer_map, + custom_transfer_map_element.predefined_transfer_map, + reloaded_custom_transfer_map_element.predefined_transfer_map, ) assert torch.allclose( custom_transfer_map_element.length, reloaded_custom_transfer_map_element.length From 3bce621e1be53a784774d8ab75cb9748d5585d8e Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:29:59 +0100 Subject: [PATCH 23/31] Some cleanup --- tests/test_clone.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/tests/test_clone.py b/tests/test_clone.py index 29f51962..0119343b 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -20,9 +20,7 @@ ], ) def test_element_buffer_location(ElementClass): - """ - Test that the buffers of cloned elements do not share memory. - """ + """Test that the buffers of cloned elements do not share memory.""" element = ElementClass(length=torch.tensor(1.0)) clone = element.clone() @@ -31,17 +29,9 @@ def test_element_buffer_location(ElementClass): assert not buffer.data_ptr() == buffer_clone.data_ptr() -@pytest.mark.parametrize( - "BeamClass", - [ - cheetah.ParameterBeam, - cheetah.ParticleBeam, - ], -) +@pytest.mark.parametrize("BeamClass", [cheetah.ParameterBeam, cheetah.ParticleBeam]) def test_beam_buffer_location(BeamClass): - """ - Test that the buffers of clones beams do not share memory. - """ + """Test that the buffers of clones beams do not share memory.""" beam = BeamClass.from_parameters() clone = beam.clone() From d6ec8fbaefe0be0b63e1ab69343bc589d4e554ef Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:33:21 +0100 Subject: [PATCH 24/31] Fix misspelling of "leaf" --- tests/test_plotting.py | 6 ++---- tests/test_tracking_nonleave_tensors.py | 25 ++++++++++--------------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 6968d160..f3b3f01b 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -105,10 +105,8 @@ def test_reference_particle_plot_vectorized_2d(): segment.plot_overview(incoming=incoming, resolution=0.1, vector_idx=(0, 2)) -def test_plotting_with_nonleave_tensors(): - """ - Test that the plotting routines can handle elements with non-leave tensors. - """ +def test_plotting_with_nonleaf_tensors(): + """Test that the plotting routines can handle elements with non-leaf tensors.""" segment = cheetah.Segment( elements=[ cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), diff --git a/tests/test_tracking_nonleave_tensors.py b/tests/test_tracking_nonleave_tensors.py index dc5f9fb2..0db9a3d7 100644 --- a/tests/test_tracking_nonleave_tensors.py +++ b/tests/test_tracking_nonleave_tensors.py @@ -19,10 +19,10 @@ cheetah.VerticalCorrector, ], ) -def test_nonleave_tracking(ElementClass): +def test_nonleaf_tracking(ElementClass): """ - Test that a beam with non-leave tensors as elements can be tracked through - elements with length parameter. + Test that a beam with non-leaf tensors as elements can be tracked through elements + with length parameter. """ beam = cheetah.ParticleBeam.from_parameters() @@ -36,21 +36,16 @@ def test_nonleave_tracking(ElementClass): @pytest.mark.parametrize( - "ElementClass", - [ - cheetah.Aperture, - cheetah.BPM, - cheetah.Screen, - ], + "ElementClass", [cheetah.Aperture, cheetah.BPM, cheetah.Screen] ) -def test_nonleave_lenghtless_elements(ElementClass): +def test_nonleaf_lenghtless_elements(ElementClass): """ - Test that a beam with non-leave tensors as elements can be tracked through - elements without length parameter. + Test that a beam with non-leaf tensors as elements can be tracked through elements + without length parameter. - The split into lengthless elements is necessary since there is no common - constructor for all element classes. Some require a length, some cannot - handle a length argument. + The split into lengthless elements is necessary since there is no common constructor + for all element classes. Some require a length, some cannot handle a length + argument. """ beam = cheetah.ParticleBeam.from_parameters() From fdf4b9b0ef8b4ae3c32db1944096360d6597a4ea Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:34:18 +0100 Subject: [PATCH 25/31] Another "leaf" spelling fix --- ...cking_nonleave_tensors.py => test_tracking_nonleaf_tensors.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_tracking_nonleave_tensors.py => test_tracking_nonleaf_tensors.py} (100%) diff --git a/tests/test_tracking_nonleave_tensors.py b/tests/test_tracking_nonleaf_tensors.py similarity index 100% rename from tests/test_tracking_nonleave_tensors.py rename to tests/test_tracking_nonleaf_tensors.py From 63df47d93a127b5d4b7de9f47ce8d971d6ab2260 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:36:08 +0100 Subject: [PATCH 26/31] Fix misspelling of "feature" --- cheetah/latticejson.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cheetah/latticejson.py b/cheetah/latticejson.py index 4f83edd3..bb9b5295 100644 --- a/cheetah/latticejson.py +++ b/cheetah/latticejson.py @@ -31,9 +31,9 @@ def convert_element(element: "cheetah.Element"): :return: Tuple of element name, element class, and element parameters """ params = { - feauture: feature2nontorch(getattr(element, feauture)) - for feauture in element.defining_features - if feauture != "name" + feature: feature2nontorch(getattr(element, feature)) + for feature in element.defining_features + if feature != "name" } return element.name, element.__class__.__name__, params From 0cf9c5cadaa356db10d7a00c1241a8a64c1097da Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:37:26 +0100 Subject: [PATCH 27/31] Move non-lead tests to the same file as other differentiability tests --- tests/test_differentiable.py | 55 ++++++++++++++++++++++++ tests/test_tracking_nonleaf_tensors.py | 58 -------------------------- 2 files changed, 55 insertions(+), 58 deletions(-) delete mode 100644 tests/test_tracking_nonleaf_tensors.py diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index c60e07dd..5bb736d3 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -1,3 +1,4 @@ +import pytest import torch from torch import nn @@ -92,3 +93,57 @@ def test_ea_incoming_particle_beam(): outgoing_beam = ea.track(incoming_beam) assert outgoing_beam.particles.grad_fn is not None + + +@pytest.mark.parametrize( + "ElementClass", + [ + cheetah.Cavity, + cheetah.Dipole, + cheetah.Drift, + cheetah.HorizontalCorrector, + cheetah.Quadrupole, + cheetah.RBend, + cheetah.Solenoid, + cheetah.TransverseDeflectingCavity, + cheetah.Undulator, + cheetah.VerticalCorrector, + ], +) +def test_nonleaf_tracking(ElementClass): + """ + Test that a beam with non-leaf tensors as elements can be tracked through elements + with length parameter. + """ + beam = cheetah.ParticleBeam.from_parameters() + + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + ElementClass(length=torch.tensor(2.0)), + ] + ) + segment.track(beam) + + +@pytest.mark.parametrize( + "ElementClass", [cheetah.Aperture, cheetah.BPM, cheetah.Screen] +) +def test_nonleaf_lenghtless_elements(ElementClass): + """ + Test that a beam with non-leaf tensors as elements can be tracked through elements + without length parameter. + + The split into lengthless elements is necessary since there is no common constructor + for all element classes. Some require a length, some cannot handle a length + argument. + """ + beam = cheetah.ParticleBeam.from_parameters() + + segment = cheetah.Segment( + elements=[ + cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), + ElementClass(is_active=True), + ] + ) + segment.track(beam) diff --git a/tests/test_tracking_nonleaf_tensors.py b/tests/test_tracking_nonleaf_tensors.py deleted file mode 100644 index 0db9a3d7..00000000 --- a/tests/test_tracking_nonleaf_tensors.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -import torch - -import cheetah - - -@pytest.mark.parametrize( - "ElementClass", - [ - cheetah.Cavity, - cheetah.Dipole, - cheetah.Drift, - cheetah.HorizontalCorrector, - cheetah.Quadrupole, - cheetah.RBend, - cheetah.Solenoid, - cheetah.TransverseDeflectingCavity, - cheetah.Undulator, - cheetah.VerticalCorrector, - ], -) -def test_nonleaf_tracking(ElementClass): - """ - Test that a beam with non-leaf tensors as elements can be tracked through elements - with length parameter. - """ - beam = cheetah.ParticleBeam.from_parameters() - - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), - ElementClass(length=torch.tensor(2.0)), - ] - ) - segment.track(beam) - - -@pytest.mark.parametrize( - "ElementClass", [cheetah.Aperture, cheetah.BPM, cheetah.Screen] -) -def test_nonleaf_lenghtless_elements(ElementClass): - """ - Test that a beam with non-leaf tensors as elements can be tracked through elements - without length parameter. - - The split into lengthless elements is necessary since there is no common constructor - for all element classes. Some require a length, some cannot handle a length - argument. - """ - beam = cheetah.ParticleBeam.from_parameters() - - segment = cheetah.Segment( - elements=[ - cheetah.Drift(length=torch.tensor(1.0, requires_grad=True)), - ElementClass(is_active=True), - ] - ) - segment.track(beam) From 6d6586a2d9144fea0d1d397e9f569336b51e4d46 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 4 Dec 2024 17:48:45 +0100 Subject: [PATCH 28/31] Update changelog because we forgot so far --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dec4d198..2703c591 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163, #284) (@cr-xu, @hespe) - `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan) - The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324) +- `e1` and `e2` in `Dipole` and `RBend` have been renamed and made more consistent between the different magnet types. They now have prefixes `dipole_` and `rbend_` respectively. (see #289) (@hespe, @jank324) +- The `_transfer_map` property of `CustomTransferMap` has been renamed to `predefined_transfer_map`. (see #289) (@hespe, @jank324) ### 🚀 Features @@ -26,6 +28,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des - Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278 #296) (@jp-ga, @cr-xu, @jank324) - `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe) - Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324) +- `Beam` and `Element` objects now have a `.clone()` method to create a deep copy (see #289) (@hespe, @jank324) ### 🐛 Bug fixes From 05763fe92d34affddaf88a51f1a1c9ffdc00ed40 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Thu, 5 Dec 2024 11:57:08 +0100 Subject: [PATCH 29/31] Add survival probability to clone implementation --- cheetah/particles/particle_beam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 8991b94c..4b547cf5 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -1079,6 +1079,7 @@ def clone(self) -> "ParticleBeam": particles=self.particles.clone(), energy=self.energy.clone(), particle_charges=self.particle_charges.clone(), + survival_probabilities=self.survival_probabilities.clone(), ) def __repr__(self) -> str: From b520eae4d45b58ce4ea93a3c3b0addd33c7ca933 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 5 Dec 2024 17:03:05 +0100 Subject: [PATCH 30/31] Modifiy clone test docstrings to also refer to value check --- tests/test_clone.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_clone.py b/tests/test_clone.py index 0119343b..96fd692e 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -19,8 +19,11 @@ cheetah.VerticalCorrector, ], ) -def test_element_buffer_location(ElementClass): - """Test that the buffers of cloned elements do not share memory.""" +def test_element_buffer_contents_and_location(ElementClass): + """ + Test that the buffers of cloned elements have the same content while not sharing the + same memory location. + """ element = ElementClass(length=torch.tensor(1.0)) clone = element.clone() @@ -30,8 +33,11 @@ def test_element_buffer_location(ElementClass): @pytest.mark.parametrize("BeamClass", [cheetah.ParameterBeam, cheetah.ParticleBeam]) -def test_beam_buffer_location(BeamClass): - """Test that the buffers of clones beams do not share memory.""" +def test_beam_buffer_contents_and_location(BeamClass): + """ + Test that the buffers of cloned beams have the same content while not sharing the + same memory location. + """ beam = BeamClass.from_parameters() clone = beam.clone() From 247f63601ba75bc7539e87bb55abd1f80642eb4f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 6 Dec 2024 12:52:10 +0100 Subject: [PATCH 31/31] Fix `Dipole` and `RBend` docstring w.r.t. `e1` and `e2` --- cheetah/accelerator/dipole.py | 4 ++-- cheetah/accelerator/rbend.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 036c3477..2df26692 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -23,8 +23,8 @@ class Dipole(Element): :param length: Length in meters. :param angle: Deflection angle in rad. :param k1: Focussing strength in 1/m^-2. Only used with `"cheetah"` tracking method. - :param e1: The angle of inclination of the entrance face [rad]. - :param e2: The angle of inclination of the exit face [rad]. + :param dipole_e1: The angle of inclination of the entrance face in rad. + :param dipole_e2: The angle of inclination of the exit face in rad. :param tilt: Tilt of the magnet in x-y plane [rad]. :param gap: The magnet gap in meters. Note that in MAD and ELEGANT: HGAP = gap/2. :param gap_exit: The magnet gap at the exit in meters. Note that in MAD and diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index 2862688c..dc889abd 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -15,8 +15,8 @@ class RBend(Dipole): :param length: Length in meters. :param angle: Deflection angle in rad. :param k1: Focussing strength in 1/m^-2. - :param e1: The angle of inclination of the entrance face [rad]. - :param e2: The angle of inclination of the exit face [rad]. + :param rbend_e1: The angle of inclination of the entrance face in rad. + :param rbend_e2: The angle of inclination of the exit face in rad. :param gap: The magnet gap in meters. Note that in MAD and ELEGANT: HGAP = gap/2. :param gap_exit: The magnet gap at the exit in meters. Note that in MAD and ELEGANT: HGAP = gap/2. Only set if different from `gap`. Only used with