Skip to content

Commit ff3430b

Browse files
cr-xujank324
andauthored
Fix vectorization issue in rbend (#173)
* Fix rbend issue w.r.t. batch input; Update tests * Update changelog * Remove old code comment --------- Co-authored-by: Jan Kaiser <jan.kaiser@desy.de> Co-authored-by: Jan Kaiser <jan.kaiser.email@googlemail.com>
1 parent 57d287e commit ff3430b

File tree

4 files changed

+23
-15
lines changed

4 files changed

+23
-15
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### 🚨 Breaking Changes
66

7-
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172) (@jank324, @cr-xu)
7+
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173) (@jank324, @cr-xu)
88

99
### 🚀 Features
1010

cheetah/accelerator/dipole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
if fringe_integral_exit is None
8686
else torch.as_tensor(fringe_integral_exit, **factory_kwargs)
8787
)
88-
# Rectangular bend
88+
# Sector bend if not specified
8989
self.e1 = (
9090
torch.as_tensor(e1, **factory_kwargs)
9191
if e1 is not None

cheetah/accelerator/rbend.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@ def __init__(
5151
device=None,
5252
dtype=torch.float32,
5353
):
54-
angle = angle if angle is not None else torch.tensor(0.0)
55-
e1 = e1 if e1 is not None else torch.tensor(0.0)
56-
e2 = e2 if e2 is not None else torch.tensor(0.0)
57-
tilt = tilt if tilt is not None else torch.tensor(0.0)
58-
fringe_integral = (
59-
fringe_integral if fringe_integral is not None else torch.tensor(0.0)
54+
super().__init__(
55+
length=length,
56+
angle=angle,
57+
e1=e1,
58+
e2=e2,
59+
tilt=tilt,
60+
fringe_integral=fringe_integral,
61+
fringe_integral_exit=fringe_integral_exit,
62+
gap=gap,
63+
name=name,
64+
device=device,
65+
dtype=dtype,
6066
)
61-
# fringe_integral_exit is left out on purpose
62-
gap = gap if gap is not None else torch.tensor(0.0)
63-
64-
e1 = e1 + angle / 2
65-
e2 = e2 + angle / 2
6667

6768
super().__init__(
6869
length=length,
@@ -77,3 +78,7 @@ def __init__(
7778
device=device,
7879
dtype=dtype,
7980
)
81+
82+
# Rectangular bend
83+
self.e1 = self.e1 + self.angle / 2
84+
self.e2 = self.e2 + self.angle / 2

tests/test_compare_ocelot.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,10 @@ def test_rbend():
525525
"tests/resources/ACHIP_EA1_2021.1351.001"
526526
)
527527
cheetah_dipole = cheetah.RBend(
528-
length=torch.tensor([0.1]), angle=torch.tensor([0.2])
528+
length=torch.tensor([0.1]),
529+
angle=torch.tensor([0.2]),
530+
fringe_integral=torch.tensor([0.1]),
531+
gap=torch.tensor([0.2]),
529532
)
530533
cheetah_segment = cheetah.Segment(
531534
[
@@ -540,7 +543,7 @@ def test_rbend():
540543
incoming_p_array = ocelot.astraBeam2particleArray(
541544
"tests/resources/ACHIP_EA1_2021.1351.001", print_params=False
542545
)
543-
ocelot_rbend = ocelot.RBend(l=0.1, angle=0.2)
546+
ocelot_rbend = ocelot.RBend(l=0.1, angle=0.2, fint=0.1, gap=0.2)
544547
lattice = ocelot.MagneticLattice(
545548
[ocelot.Drift(l=0.1), ocelot_rbend, ocelot.Drift(l=0.1)]
546549
)

0 commit comments

Comments
 (0)