Skip to content

Commit 66d5847

Browse files
authored
Merge pull request #289 from Hespe/clone-methods
Add `clone` method for beams and elements
2 parents 866b243 + 247f636 commit 66d5847

29 files changed

+427
-116
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
1212
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)
1313
- The way `dtype`s are determined is now more in line with PyTorch's conventions. This may cause different-than-expected `dtype`s in old code. (see #254) (@hespe, @jank324)
1414
- `Beam.parameters()` no longer shadows `torch.nn.Module.parameters()`. The previously returned properties now need to be queried individually. (see #300) (@hespe)
15+
- `e1` and `e2` in `Dipole` and `RBend` have been renamed and made more consistent between the different magnet types. They now have prefixes `dipole_` and `rbend_` respectively. (see #289) (@hespe, @jank324)
16+
- The `_transfer_map` property of `CustomTransferMap` has been renamed to `predefined_transfer_map`. (see #289) (@hespe, @jank324)
1517

1618
### 🚀 Features
1719

@@ -27,6 +29,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des
2729
- Add `TransverseDeflectingCavity` element (following the Bmad-X implementation) (see #240, #278 #296) (@jp-ga, @cr-xu, @jank324)
2830
- `Dipole` and `RBend` now take a focusing moment `k1` (see #235, #247) (@hespe)
2931
- Implement a converter for lattice files imported from Elegant (see #222, #251, #273, #281) (@hespe, @jank324)
32+
- `Beam` and `Element` objects now have a `.clone()` method to create a deep copy (see #289) (@hespe, @jank324)
3033

3134
### 🐛 Bug fixes
3235

cheetah/accelerator/aperture.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
126126

127127
@property
128128
def defining_features(self) -> list[str]:
129-
return super().defining_features + [
130-
"x_max",
131-
"y_max",
132-
"shape",
133-
"is_active",
134-
]
129+
return super().defining_features + ["x_max", "y_max", "shape", "is_active"]
135130

136131
def __repr__(self) -> str:
137132
return (

cheetah/accelerator/bpm.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from copy import deepcopy
21
from typing import Optional
32

43
import matplotlib.pyplot as plt
@@ -44,7 +43,7 @@ def track(self, incoming: Beam) -> Beam:
4443
else:
4544
raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}")
4645

47-
return deepcopy(incoming)
46+
return incoming.clone()
4847

4948
def split(self, resolution: torch.Tensor) -> list[Element]:
5049
return [self]
@@ -60,7 +59,7 @@ def plot(self, ax: plt.Axes, s: float, vector_idx: Optional[tuple] = None) -> No
6059

6160
@property
6261
def defining_features(self) -> list[str]:
63-
return super().defining_features
62+
return super().defining_features + ["is_active"]
6463

6564
def __repr__(self) -> str:
6665
return f"{self.__class__.__name__}(name={repr(self.name)})"

cheetah/accelerator/custom_transfer_map.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,31 @@ class CustomTransferMap(Element):
1818

1919
def __init__(
2020
self,
21-
transfer_map: torch.Tensor,
21+
predefined_transfer_map: torch.Tensor,
2222
length: Optional[torch.Tensor] = None,
2323
name: Optional[str] = None,
2424
device=None,
2525
dtype=None,
2626
) -> None:
27-
device, dtype = verify_device_and_dtype([transfer_map, length], device, dtype)
27+
device, dtype = verify_device_and_dtype(
28+
[predefined_transfer_map, length], device, dtype
29+
)
2830
factory_kwargs = {"device": device, "dtype": dtype}
2931
super().__init__(name=name)
3032

31-
assert isinstance(transfer_map, torch.Tensor)
32-
assert transfer_map.shape[-2:] == (7, 7)
33+
assert isinstance(predefined_transfer_map, torch.Tensor)
34+
assert predefined_transfer_map.shape[-2:] == (7, 7)
3335

3436
self.register_buffer(
35-
"_transfer_map", torch.as_tensor(transfer_map, **factory_kwargs)
37+
"predefined_transfer_map",
38+
torch.as_tensor(predefined_transfer_map, **factory_kwargs),
3639
)
3740
self.register_buffer(
3841
"length",
3942
(
4043
torch.as_tensor(length, **factory_kwargs)
4144
if length is not None
42-
else torch.zeros(transfer_map.shape[:-2], **factory_kwargs)
45+
else torch.zeros(predefined_transfer_map.shape[:-2], **factory_kwargs)
4346
),
4447
)
4548

@@ -83,21 +86,23 @@ def from_merging_elements(
8386
)
8487

8588
def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
86-
return self._transfer_map
89+
return self.predefined_transfer_map
8790

8891
@property
8992
def is_skippable(self) -> bool:
9093
return True
9194

9295
def __repr__(self):
9396
return (
94-
f"{self.__class__.__name__}(transfer_map={repr(self._transfer_map)}, "
97+
f"{self.__class__.__name__}("
98+
+ f"predefined_transfer_map={repr(self.predefined_transfer_map)}, "
9599
+ f"length={repr(self.length)}, "
96100
+ f"name={repr(self.name)})"
97101
)
98102

103+
@property
99104
def defining_features(self) -> list[str]:
100-
return super().defining_features + ["transfer_map"]
105+
return super().defining_features + ["length", "predefined_transfer_map"]
101106

102107
def split(self, resolution: torch.Tensor) -> list[Element]:
103108
return [self]

cheetah/accelerator/dipole.py

+49-33
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class Dipole(Element):
2323
:param length: Length in meters.
2424
:param angle: Deflection angle in rad.
2525
:param k1: Focussing strength in 1/m^-2. Only used with `"cheetah"` tracking method.
26-
:param e1: The angle of inclination of the entrance face [rad].
27-
:param e2: The angle of inclination of the exit face [rad].
26+
:param dipole_e1: The angle of inclination of the entrance face in rad.
27+
:param dipole_e2: The angle of inclination of the exit face in rad.
2828
:param tilt: Tilt of the magnet in x-y plane [rad].
2929
:param gap: The magnet gap in meters. Note that in MAD and ELEGANT: HGAP = gap/2.
3030
:param gap_exit: The magnet gap at the exit in meters. Note that in MAD and
@@ -49,8 +49,8 @@ def __init__(
4949
length: torch.Tensor,
5050
angle: Optional[torch.Tensor] = None,
5151
k1: Optional[torch.Tensor] = None,
52-
e1: Optional[torch.Tensor] = None,
53-
e2: Optional[torch.Tensor] = None,
52+
dipole_e1: Optional[torch.Tensor] = None,
53+
dipole_e2: Optional[torch.Tensor] = None,
5454
tilt: Optional[torch.Tensor] = None,
5555
gap: Optional[torch.Tensor] = None,
5656
gap_exit: Optional[torch.Tensor] = None,
@@ -68,8 +68,8 @@ def __init__(
6868
length,
6969
angle,
7070
k1,
71-
e1,
72-
e2,
71+
dipole_e1,
72+
dipole_e2,
7373
tilt,
7474
gap,
7575
gap_exit,
@@ -88,55 +88,55 @@ def __init__(
8888
(
8989
torch.as_tensor(angle, **factory_kwargs)
9090
if angle is not None
91-
else torch.zeros_like(self.length)
91+
else torch.tensor(0.0, **factory_kwargs)
9292
),
9393
)
9494
self.register_buffer(
9595
"k1",
9696
(
9797
torch.as_tensor(k1, **factory_kwargs)
9898
if k1 is not None
99-
else torch.zeros_like(self.length)
99+
else torch.tensor(0.0, **factory_kwargs)
100100
),
101101
)
102102
self.register_buffer(
103-
"e1",
103+
"_e1",
104104
(
105-
torch.as_tensor(e1, **factory_kwargs)
106-
if e1 is not None
107-
else torch.zeros_like(self.length)
105+
torch.as_tensor(dipole_e1, **factory_kwargs)
106+
if dipole_e1 is not None
107+
else torch.tensor(0.0, **factory_kwargs)
108108
),
109109
)
110110
self.register_buffer(
111-
"e2",
111+
"_e2",
112112
(
113-
torch.as_tensor(e2, **factory_kwargs)
114-
if e2 is not None
115-
else torch.zeros_like(self.length)
113+
torch.as_tensor(dipole_e2, **factory_kwargs)
114+
if dipole_e2 is not None
115+
else torch.tensor(0.0, **factory_kwargs)
116116
),
117117
)
118118
self.register_buffer(
119119
"fringe_integral",
120120
(
121121
torch.as_tensor(fringe_integral, **factory_kwargs)
122122
if fringe_integral is not None
123-
else torch.zeros_like(self.length)
123+
else torch.tensor(0.0, **factory_kwargs)
124124
),
125125
)
126126
self.register_buffer(
127127
"fringe_integral_exit",
128128
(
129129
self.fringe_integral
130130
if fringe_integral_exit is None
131-
else torch.as_tensor(fringe_integral_exit, **factory_kwargs)
131+
else torch.tensor(fringe_integral_exit, **factory_kwargs)
132132
),
133133
)
134134
self.register_buffer(
135135
"gap",
136136
(
137137
torch.as_tensor(gap, **factory_kwargs)
138138
if gap is not None
139-
else torch.zeros_like(self.length)
139+
else torch.tensor(0.0, **factory_kwargs)
140140
),
141141
)
142142
self.register_buffer(
@@ -152,7 +152,7 @@ def __init__(
152152
(
153153
torch.as_tensor(tilt, **factory_kwargs)
154154
if tilt is not None
155-
else torch.zeros_like(self.length)
155+
else torch.tensor(0.0, **factory_kwargs)
156156
),
157157
)
158158
self.fringe_at = fringe_at
@@ -163,6 +163,22 @@ def __init__(
163163
def hx(self) -> torch.Tensor:
164164
return torch.where(self.length == 0.0, 0.0, self.angle / self.length)
165165

166+
@property
167+
def dipole_e1(self) -> torch.Tensor:
168+
return self._e1
169+
170+
@dipole_e1.setter
171+
def dipole_e1(self, value: torch.Tensor):
172+
self._e1 = value
173+
174+
@property
175+
def dipole_e2(self) -> torch.Tensor:
176+
return self._e2
177+
178+
@dipole_e2.setter
179+
def dipole_e2(self, value: torch.Tensor):
180+
self._e2 = value
181+
166182
@property
167183
def is_skippable(self) -> bool:
168184
return self.tracking_method == "cheetah"
@@ -379,7 +395,7 @@ def _bmadx_fringe_linear(
379395
:return: px, py final Bmad cannonical coordinates.
380396
"""
381397
g = self.angle / self.length
382-
e = self.e1 * (location == "entrance") + self.e2 * (location == "exit")
398+
e = self._e1 * (location == "entrance") + self._e2 * (location == "exit")
383399
f_int = self.fringe_integral * (
384400
location == "entrance"
385401
) + self.fringe_integral_exit * (location == "exit")
@@ -432,18 +448,18 @@ def _transfer_map_enter(self) -> torch.Tensor:
432448
device = self.length.device
433449
dtype = self.length.dtype
434450

435-
sec_e = 1.0 / torch.cos(self.e1)
451+
sec_e = 1.0 / torch.cos(self._e1)
436452
phi = (
437453
self.fringe_integral
438454
* self.hx
439455
* self.gap
440456
* sec_e
441-
* (1 + torch.sin(self.e1) ** 2)
457+
* (1 + torch.sin(self._e1) ** 2)
442458
)
443459

444460
tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
445-
tm[..., 1, 0] = self.hx * torch.tan(self.e1)
446-
tm[..., 3, 2] = -self.hx * torch.tan(self.e1 - phi)
461+
tm[..., 1, 0] = self.hx * torch.tan(self._e1)
462+
tm[..., 3, 2] = -self.hx * torch.tan(self._e1 - phi)
447463

448464
return tm
449465

@@ -452,18 +468,18 @@ def _transfer_map_exit(self) -> torch.Tensor:
452468
device = self.length.device
453469
dtype = self.length.dtype
454470

455-
sec_e = 1.0 / torch.cos(self.e2)
471+
sec_e = 1.0 / torch.cos(self._e2)
456472
phi = (
457473
self.fringe_integral_exit
458474
* self.hx
459475
* self.gap
460476
* sec_e
461-
* (1 + torch.sin(self.e2) ** 2)
477+
* (1 + torch.sin(self._e2) ** 2)
462478
)
463479

464480
tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
465-
tm[..., 1, 0] = self.hx * torch.tan(self.e2)
466-
tm[..., 3, 2] = -self.hx * torch.tan(self.e2 - phi)
481+
tm[..., 1, 0] = self.hx * torch.tan(self._e2)
482+
tm[..., 3, 2] = -self.hx * torch.tan(self._e2 - phi)
467483

468484
return tm
469485

@@ -477,8 +493,8 @@ def __repr__(self):
477493
f"{self.__class__.__name__}(length={repr(self.length)}, "
478494
+ f"angle={repr(self.angle)}, "
479495
+ f"k1={repr(self.k1)}, "
480-
+ f"e1={repr(self.e1)},"
481-
+ f"e2={repr(self.e2)},"
496+
+ f"dipole_e1={repr(self.dipole_e1)},"
497+
+ f"dipole_e2={repr(self.dipole_e2)},"
482498
+ f"tilt={repr(self.tilt)},"
483499
+ f"gap={repr(self.gap)},"
484500
+ f"gap_exit={repr(self.gap_exit)},"
@@ -496,8 +512,8 @@ def defining_features(self) -> list[str]:
496512
"length",
497513
"angle",
498514
"k1",
499-
"e1",
500-
"e2",
515+
"dipole_e1",
516+
"dipole_e2",
501517
"tilt",
502518
"gap",
503519
"gap_exit",

cheetah/accelerator/element.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from copy import deepcopy
23
from typing import Optional
34

45
import matplotlib.pyplot as plt
@@ -110,7 +111,20 @@ def defining_features(self) -> list[str]:
110111
NOTE: When overriding this property, make sure to call the super method and
111112
extend the list it returns.
112113
"""
113-
return []
114+
return ["name"]
115+
116+
def clone(self) -> "Element":
117+
"""Create a copy of the element which does not share the underlying memory."""
118+
return self.__class__(
119+
**{
120+
feature: (
121+
getattr(self, feature).clone()
122+
if isinstance(getattr(self, feature), torch.Tensor)
123+
else deepcopy(getattr(self, feature))
124+
)
125+
for feature in self.defining_features
126+
}
127+
)
114128

115129
@abstractmethod
116130
def split(self, resolution: torch.Tensor) -> list["Element"]:

cheetah/accelerator/horizontal_corrector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
(
4545
torch.as_tensor(angle, **factory_kwargs)
4646
if angle is not None
47-
else torch.zeros_like(self.length)
47+
else torch.tensor(0.0, **factory_kwargs)
4848
),
4949
)
5050

cheetah/accelerator/quadrupole.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
(
5656
torch.as_tensor(k1, **factory_kwargs)
5757
if k1 is not None
58-
else torch.zeros_like(self.length)
58+
else torch.tensor(0.0, **factory_kwargs)
5959
),
6060
)
6161
self.register_buffer(
@@ -71,7 +71,7 @@ def __init__(
7171
(
7272
torch.as_tensor(tilt, **factory_kwargs)
7373
if tilt is not None
74-
else torch.zeros_like(self.length)
74+
else torch.tensor(0.0, **factory_kwargs)
7575
),
7676
)
7777
self.num_steps = num_steps

0 commit comments

Comments
 (0)