Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clone method for beams and elements #289

Merged
merged 40 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0d2604b
Add cloning method for beams
Hespe Oct 28, 2024
8a883d8
Remove unnecessary comments
Hespe Oct 28, 2024
3e99a57
Add clone method for elements
Hespe Oct 28, 2024
e788b52
Remove deepcopy from plotting method
Hespe Oct 28, 2024
2754536
Test clone implementations
Hespe Oct 28, 2024
404f4c9
Add test for tracking with nonleave tensors
Hespe Oct 28, 2024
94c52df
Remove deepcopy from Screen
Hespe Oct 28, 2024
9aca0cd
Test plotting with nonleave tensors
Hespe Oct 28, 2024
4560d5d
Merge branch 'master' into clone-methods
Hespe Nov 21, 2024
936a1d6
Fix formatting
Hespe Nov 21, 2024
9fdab40
Mark Beam.clone as abstract
Hespe Nov 21, 2024
39324bf
Merge branch 'master' into clone-methods
jank324 Nov 22, 2024
b9e22d2
Write generic clone method for elements
Hespe Nov 22, 2024
980f42b
Add missing argument to RBend constructor
Hespe Nov 22, 2024
4572bb7
Add missing is_active feature for BPM
Hespe Nov 25, 2024
cfd0ef9
Add clone override for CustomTransferMap
Hespe Nov 25, 2024
bc914da
Apply suggestions from code review
jank324 Nov 25, 2024
b126040
Merge branch 'clone-methods' of github.com:Hespe/cheetah into clone-m…
Hespe Nov 25, 2024
bdeb961
Add name as a defining feature
Hespe Nov 25, 2024
81b45ff
Merge branch 'master' into clone-methods
Hespe Nov 25, 2024
94a204b
Fix issue where `name` was considered twice with LatticeJSON saving
jank324 Dec 4, 2024
ba6ff84
Make sure `is_active` is part of `defining_features` when it can be s…
jank324 Dec 4, 2024
9d417fb
Fix `RBend` vs `Dipole` `defining_features` issue
jank324 Dec 4, 2024
f227773
Merge branch 'master' into clone-methods
jank324 Dec 4, 2024
456cce0
Add test that finds `CustomTransferMap` cloning issue
jank324 Dec 4, 2024
ab8f3ae
Merge branch 'clone-methods' of https://github.com/Hespe/cheetah into…
jank324 Dec 4, 2024
21023b9
Fix `deepcopy` test failure resulting from merging `master`
jank324 Dec 4, 2024
cd5652a
Fix issue with `CustomTransferMap` saving and cloning
jank324 Dec 4, 2024
3bce621
Some cleanup
jank324 Dec 4, 2024
d6ec8fb
Fix misspelling of "leaf"
jank324 Dec 4, 2024
fdf4b9b
Another "leaf" spelling fix
jank324 Dec 4, 2024
63df47d
Fix misspelling of "feature"
jank324 Dec 4, 2024
0cf9c5c
Move non-lead tests to the same file as other differentiability tests
jank324 Dec 4, 2024
6d6586a
Update changelog because we forgot so far
jank324 Dec 4, 2024
05763fe
Add survival probability to clone implementation
Hespe Dec 5, 2024
ad2016b
Merge branch 'master' into clone-methods
jank324 Dec 5, 2024
b520eae
Modifiy clone test docstrings to also refer to value check
jank324 Dec 5, 2024
0dea009
Merge branch 'clone-methods' of https://github.com/Hespe/cheetah into…
jank324 Dec 5, 2024
10716ce
Merge branch 'master' into clone-methods
Hespe Dec 6, 2024
247f636
Fix `Dipole` and `RBend` docstring w.r.t. `e1` and `e2`
jank324 Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)
- `Beam.parameters()` no longer shadows `torch.nn.Module.parameters()`. The previously returned properties now need to be queried individually. (see #300) (@hespe)
- `e1` and `e2` in `Dipole` and `RBend` have been renamed and made more consistent between the different magnet types. They now have prefixes `dipole_` and `rbend_` respectively. (see #289) (@hespe, @jank324)
- The `_transfer_map` property of `CustomTransferMap` has been renamed to `predefined_transfer_map`. (see #289) (@hespe, @jank324)

### 🚀 Features

Expand All @@ -27,6 +29,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278 #296) (@jp-ga, @cr-xu, @jank324)
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
- Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324)
- `Beam` and `Element` objects now have a `.clone()` method to create a deep copy (see #289) (@hespe, @jank324)

### 🐛 Bug fixes

Expand Down
7 changes: 1 addition & 6 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No

@property
def defining_features(self) -> list[str]:
return super().defining_features + [
"x_max",
"y_max",
"shape",
"is_active",
]
return super().defining_features + ["x_max", "y_max", "shape", "is_active"]

def __repr__(self) -> str:
return (
Expand Down
5 changes: 2 additions & 3 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Optional

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -44,7 +43,7 @@ def track(self, incoming: Beam) -> Beam:
else:
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")

return deepcopy(incoming)
return incoming.clone()

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]
Expand All @@ -60,7 +59,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No

@property
def defining_features(self) -> list[str]:
return super().defining_features
return super().defining_features + ["is_active"]

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={repr(self.name)})"
23 changes: 14 additions & 9 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,31 @@ class CustomTransferMap(Element):

def __init__(
self,
transfer_map: torch.Tensor,
predefined_transfer_map: torch.Tensor,
length: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
device, dtype = verify_device_and_dtype(
[predefined_transfer_map, length], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

assert isinstance(transfer_map, torch.Tensor)
assert transfer_map.shape[-2:] == (7, 7)
assert isinstance(predefined_transfer_map, torch.Tensor)
assert predefined_transfer_map.shape[-2:] == (7, 7)

self.register_buffer(
"_transfer_map", torch.as_tensor(transfer_map, **factory_kwargs)
"predefined_transfer_map",
torch.as_tensor(predefined_transfer_map, **factory_kwargs),
)
self.register_buffer(
"length",
(
torch.as_tensor(length, **factory_kwargs)
if length is not None
else torch.zeros(transfer_map.shape[:-2], **factory_kwargs)
else torch.zeros(predefined_transfer_map.shape[:-2], **factory_kwargs)
),
)

Expand Down Expand Up @@ -83,21 +86,23 @@ def from_merging_elements(
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return self._transfer_map
return self.predefined_transfer_map

@property
def is_skippable(self) -> bool:
return True

def __repr__(self):
return (
f"{self.__class__.__name__}(transfer_map={repr(self._transfer_map)}, "
f"{self.__class__.__name__}("
+ f"predefined_transfer_map={repr(self.predefined_transfer_map)}, "
+ f"length={repr(self.length)}, "
+ f"name={repr(self.name)})"
)

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["transfer_map"]
return super().defining_features + ["length", "predefined_transfer_map"]

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]
Expand Down
82 changes: 49 additions & 33 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class Dipole(Element):
:param length: Length in meters.
:param angle: Deflection angle in rad.
:param k1: Focussing strength in 1/m^-2. Only used with `"cheetah"` tracking method.
:param e1: The angle of inclination of the entrance face [rad].
:param e2: The angle of inclination of the exit face [rad].
:param dipole_e1: The angle of inclination of the entrance face in rad.
:param dipole_e2: The angle of inclination of the exit face in rad.
:param tilt: Tilt of the magnet in x-y plane [rad].
:param gap: The magnet gap in meters. Note that in MAD and ELEGANT: HGAP = gap/2.
:param gap_exit: The magnet gap at the exit in meters. Note that in MAD and
Expand All @@ -49,8 +49,8 @@ def __init__(
length: torch.Tensor,
angle: Optional[torch.Tensor] = None,
k1: Optional[torch.Tensor] = None,
e1: Optional[torch.Tensor] = None,
e2: Optional[torch.Tensor] = None,
dipole_e1: Optional[torch.Tensor] = None,
dipole_e2: Optional[torch.Tensor] = None,
tilt: Optional[torch.Tensor] = None,
gap: Optional[torch.Tensor] = None,
gap_exit: Optional[torch.Tensor] = None,
Expand All @@ -68,8 +68,8 @@ def __init__(
length,
angle,
k1,
e1,
e2,
dipole_e1,
dipole_e2,
tilt,
gap,
gap_exit,
Expand All @@ -88,55 +88,55 @@ def __init__(
(
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"k1",
(
torch.as_tensor(k1, **factory_kwargs)
if k1 is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"e1",
"_e1",
(
torch.as_tensor(e1, **factory_kwargs)
if e1 is not None
else torch.zeros_like(self.length)
torch.as_tensor(dipole_e1, **factory_kwargs)
if dipole_e1 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"e2",
"_e2",
(
torch.as_tensor(e2, **factory_kwargs)
if e2 is not None
else torch.zeros_like(self.length)
torch.as_tensor(dipole_e2, **factory_kwargs)
if dipole_e2 is not None
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"fringe_integral",
(
torch.as_tensor(fringe_integral, **factory_kwargs)
if fringe_integral is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
"fringe_integral_exit",
(
self.fringe_integral
if fringe_integral_exit is None
else torch.as_tensor(fringe_integral_exit, **factory_kwargs)
else torch.tensor(fringe_integral_exit, **factory_kwargs)
),
)
self.register_buffer(
"gap",
(
torch.as_tensor(gap, **factory_kwargs)
if gap is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
Expand All @@ -152,7 +152,7 @@ def __init__(
(
torch.as_tensor(tilt, **factory_kwargs)
if tilt is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.fringe_at = fringe_at
Expand All @@ -163,6 +163,22 @@ def __init__(
def hx(self) -> torch.Tensor:
return torch.where(self.length == 0.0, 0.0, self.angle / self.length)

@property
def dipole_e1(self) -> torch.Tensor:
return self._e1

@dipole_e1.setter
def dipole_e1(self, value: torch.Tensor):
self._e1 = value

@property
def dipole_e2(self) -> torch.Tensor:
return self._e2

@dipole_e2.setter
def dipole_e2(self, value: torch.Tensor):
self._e2 = value

@property
def is_skippable(self) -> bool:
return self.tracking_method == "cheetah"
Expand Down Expand Up @@ -379,7 +395,7 @@ def _bmadx_fringe_linear(
:return: px, py final Bmad cannonical coordinates.
"""
g = self.angle / self.length
e = self.e1 * (location == "entrance") + self.e2 * (location == "exit")
e = self._e1 * (location == "entrance") + self._e2 * (location == "exit")
f_int = self.fringe_integral * (
location == "entrance"
) + self.fringe_integral_exit * (location == "exit")
Expand Down Expand Up @@ -432,18 +448,18 @@ def _transfer_map_enter(self) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

sec_e = 1.0 / torch.cos(self.e1)
sec_e = 1.0 / torch.cos(self._e1)
phi = (
self.fringe_integral
* self.hx
* self.gap
* sec_e
* (1 + torch.sin(self.e1) ** 2)
* (1 + torch.sin(self._e1) ** 2)
)

tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
tm[..., 1, 0] = self.hx * torch.tan(self.e1)
tm[..., 3, 2] = -self.hx * torch.tan(self.e1 - phi)
tm[..., 1, 0] = self.hx * torch.tan(self._e1)
tm[..., 3, 2] = -self.hx * torch.tan(self._e1 - phi)

return tm

Expand All @@ -452,18 +468,18 @@ def _transfer_map_exit(self) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

sec_e = 1.0 / torch.cos(self.e2)
sec_e = 1.0 / torch.cos(self._e2)
phi = (
self.fringe_integral_exit
* self.hx
* self.gap
* sec_e
* (1 + torch.sin(self.e2) ** 2)
* (1 + torch.sin(self._e2) ** 2)
)

tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
tm[..., 1, 0] = self.hx * torch.tan(self.e2)
tm[..., 3, 2] = -self.hx * torch.tan(self.e2 - phi)
tm[..., 1, 0] = self.hx * torch.tan(self._e2)
tm[..., 3, 2] = -self.hx * torch.tan(self._e2 - phi)

return tm

Expand All @@ -477,8 +493,8 @@ def __repr__(self):
f"{self.__class__.__name__}(length={repr(self.length)}, "
+ f"angle={repr(self.angle)}, "
+ f"k1={repr(self.k1)}, "
+ f"e1={repr(self.e1)},"
+ f"e2={repr(self.e2)},"
+ f"dipole_e1={repr(self.dipole_e1)},"
+ f"dipole_e2={repr(self.dipole_e2)},"
+ f"tilt={repr(self.tilt)},"
+ f"gap={repr(self.gap)},"
+ f"gap_exit={repr(self.gap_exit)},"
Expand All @@ -496,8 +512,8 @@ def defining_features(self) -> list[str]:
"length",
"angle",
"k1",
"e1",
"e2",
"dipole_e1",
"dipole_e2",
"tilt",
"gap",
"gap_exit",
Expand Down
16 changes: 15 additions & 1 deletion cheetah/accelerator/element.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Optional

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -110,7 +111,20 @@ def defining_features(self) -> list[str]:
NOTE: When overriding this property, make sure to call the super method and
extend the list it returns.
"""
return []
return ["name"]

def clone(self) -> "Element":
"""Create a copy of the element which does not share the underlying memory."""
return self.__class__(
**{
feature: (
getattr(self, feature).clone()
if isinstance(getattr(self, feature), torch.Tensor)
else deepcopy(getattr(self, feature))
)
for feature in self.defining_features
}
)

@abstractmethod
def split(self, resolution: torch.Tensor) -> list["Element"]:
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
(
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)

Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
(
torch.as_tensor(k1, **factory_kwargs)
if k1 is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.register_buffer(
Expand All @@ -71,7 +71,7 @@ def __init__(
(
torch.as_tensor(tilt, **factory_kwargs)
if tilt is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.num_steps = num_steps
Expand Down
Loading
Loading