Skip to content

Commit 5db7e40

Browse files
authored
Merge pull request #157 from desy-ml/154-missing-batch-execution-compatibility-in-dipole-rotation-matrix-etc
Fix the batch execution for elements with tilts or misalignments
2 parents 76bb848 + 6b5181d commit 5db7e40

File tree

5 files changed

+215
-66
lines changed

5 files changed

+215
-66
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### 🚨 Breaking Changes
66

7-
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116) (@jank324)
7+
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157) (@jank324, @cr-xu)
88

99
### 🚀 Features
1010

cheetah/accelerator.py

+43-47
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,11 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
379379
energy=energy,
380380
)
381381

382-
if torch.all(self.misalignment[:, 0] == 0) and torch.all(
383-
self.misalignment[:, 1] == 0
384-
):
382+
if torch.all(self.misalignment == 0):
385383
return R
386384
else:
387-
R_exit, R_entry = misalignment_matrix(self.misalignment)
388-
R = torch.matmul(R_exit, torch.matmul(R, R_entry))
385+
R_entry, R_exit = misalignment_matrix(self.misalignment)
386+
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
389387
return R
390388

391389
def broadcast(self, shape: Size) -> Element:
@@ -542,23 +540,21 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
542540
hx=self.hx,
543541
tilt=torch.zeros_like(self.length),
544542
energy=energy,
545-
)
543+
) # Tilt is applied after adding edges
546544
else: # Reduce to Thin-Corrector
547545
R = torch.eye(7, device=device, dtype=dtype).repeat(
548546
(*self.length.shape, 1, 1)
549547
)
550-
R[:, 0, 1] = self.length
551-
R[:, 2, 6] = self.angle
552-
R[:, 2, 3] = self.length
548+
R[..., 0, 1] = self.length
549+
R[..., 2, 6] = self.angle
550+
R[..., 2, 3] = self.length
553551

554552
# Apply fringe fields
555553
R = torch.matmul(R_exit, torch.matmul(R, R_enter))
556554
# Apply rotation for tilted magnets
557-
# TODO: Are we applying tilt twice (here and base_rmatrix)?
558555
R = torch.matmul(
559556
rotation_matrix(-self.tilt), torch.matmul(R, rotation_matrix(self.tilt))
560557
)
561-
562558
return R
563559

564560
def _transfer_map_enter(self) -> torch.Tensor:
@@ -576,8 +572,8 @@ def _transfer_map_enter(self) -> torch.Tensor:
576572
)
577573

578574
tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
579-
tm[:, 1, 0] = self.hx * torch.tan(self.e1)
580-
tm[:, 3, 2] = -self.hx * torch.tan(self.e1 - phi)
575+
tm[..., 1, 0] = self.hx * torch.tan(self.e1)
576+
tm[..., 3, 2] = -self.hx * torch.tan(self.e1 - phi)
581577

582578
return tm
583579

@@ -596,8 +592,8 @@ def _transfer_map_exit(self) -> torch.Tensor:
596592
)
597593

598594
tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
599-
tm[:, 1, 0] = self.hx * torch.tan(self.e2)
600-
tm[:, 3, 2] = -self.hx * torch.tan(self.e2 - phi)
595+
tm[..., 1, 0] = self.hx * torch.tan(self.e2)
596+
tm[..., 3, 2] = -self.hx * torch.tan(self.e2 - phi)
601597

602598
return tm
603599

@@ -1448,11 +1444,11 @@ def track(self, incoming: Beam) -> Beam:
14481444
copy_of_incoming = deepcopy(incoming)
14491445

14501446
if isinstance(incoming, ParameterBeam):
1451-
copy_of_incoming._mu[:, 0] -= self.misalignment[:, 0]
1452-
copy_of_incoming._mu[:, 2] -= self.misalignment[:, 1]
1447+
copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0]
1448+
copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1]
14531449
elif isinstance(incoming, ParticleBeam):
1454-
copy_of_incoming.particles[:, :, 0] -= self.misalignment[:, 0]
1455-
copy_of_incoming.particles[:, :, 1] -= self.misalignment[:, 1]
1450+
copy_of_incoming.particles[..., :, 0] -= self.misalignment[..., 0]
1451+
copy_of_incoming.particles[..., :, 1] -= self.misalignment[..., 1]
14561452

14571453
self.set_read_beam(copy_of_incoming)
14581454

@@ -1476,18 +1472,18 @@ def reading(self) -> torch.Tensor:
14761472
)
14771473
elif isinstance(read_beam, ParameterBeam):
14781474
transverse_mu = torch.stack(
1479-
[read_beam._mu[:, 0], read_beam._mu[:, 2]], dim=1
1475+
[read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1
14801476
)
14811477
transverse_cov = torch.stack(
14821478
[
14831479
torch.stack(
1484-
[read_beam._cov[:, 0, 0], read_beam._cov[:, 0, 2]], dim=1
1480+
[read_beam._cov[..., 0, 0], read_beam._cov[..., 0, 2]], dim=-1
14851481
),
14861482
torch.stack(
1487-
[read_beam._cov[:, 2, 0], read_beam._cov[:, 2, 2]], dim=1
1483+
[read_beam._cov[..., 2, 0], read_beam._cov[..., 2, 2]], dim=-1
14881484
),
14891485
],
1490-
dim=1,
1486+
dim=-1,
14911487
)
14921488
dist = [
14931489
MultivariateNormal(
@@ -1767,9 +1763,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
17671763
)
17681764

17691765
tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1))
1770-
tm[:, 0, 1] = self.length
1771-
tm[:, 2, 3] = self.length
1772-
tm[:, 4, 5] = self.length * igamma2
1766+
tm[..., 0, 1] = self.length
1767+
tm[..., 2, 3] = self.length
1768+
tm[..., 4, 5] = self.length * igamma2
17731769

17741770
return tm
17751771

@@ -1866,38 +1862,38 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
18661862
r56 -= self.length / (beta * beta * gamma2)
18671863

18681864
R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
1869-
R[:, 0, 0] = c**2
1870-
R[:, 0, 1] = c * s_k
1871-
R[:, 0, 2] = s * c
1872-
R[:, 0, 3] = s * s_k
1873-
R[:, 1, 0] = -self.k * s * c
1874-
R[:, 1, 1] = c**2
1875-
R[:, 1, 2] = -self.k * s**2
1876-
R[:, 1, 3] = s * c
1877-
R[:, 2, 0] = -s * c
1878-
R[:, 2, 1] = -s * s_k
1879-
R[:, 2, 2] = c**2
1880-
R[:, 2, 3] = c * s_k
1881-
R[:, 3, 0] = self.k * s**2
1882-
R[:, 3, 1] = -s * c
1883-
R[:, 3, 2] = -self.k * s * c
1884-
R[:, 3, 3] = c**2
1885-
R[:, 4, 5] = r56
1865+
R[..., 0, 0] = c**2
1866+
R[..., 0, 1] = c * s_k
1867+
R[..., 0, 2] = s * c
1868+
R[..., 0, 3] = s * s_k
1869+
R[..., 1, 0] = -self.k * s * c
1870+
R[..., 1, 1] = c**2
1871+
R[..., 1, 2] = -self.k * s**2
1872+
R[..., 1, 3] = s * c
1873+
R[..., 2, 0] = -s * c
1874+
R[..., 2, 1] = -s * s_k
1875+
R[..., 2, 2] = c**2
1876+
R[..., 2, 3] = c * s_k
1877+
R[..., 3, 0] = self.k * s**2
1878+
R[..., 3, 1] = -s * c
1879+
R[..., 3, 2] = -self.k * s * c
1880+
R[..., 3, 3] = c**2
1881+
R[..., 4, 5] = r56
18861882

18871883
R = R.real
18881884

1889-
if self.misalignment[0] == 0 and self.misalignment[1] == 0:
1885+
if torch.all(self.misalignment == 0):
18901886
return R
18911887
else:
1892-
R_exit, R_entry = misalignment_matrix(self.misalignment)
1893-
R = torch.matmul(R_exit, torch.matmul(R, R_entry))
1888+
R_entry, R_exit = misalignment_matrix(self.misalignment)
1889+
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
18941890
return R
18951891

18961892
def broadcast(self, shape: Size) -> Element:
18971893
return self.__class__(
18981894
length=self.length.repeat(shape),
18991895
k=self.k.repeat(shape),
1900-
misalignment=self.misalignment.repeat(shape),
1896+
misalignment=self.misalignment.repeat((*shape, 1)),
19011897
name=self.name,
19021898
)
19031899

cheetah/track_methods.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ def rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
2020
cs = torch.cos(angle)
2121
sn = torch.sin(angle)
2222

23-
tm = torch.eye(7, dtype=angle.dtype, device=angle.device)
24-
tm[0, 0] = cs
25-
tm[0, 2] = sn
26-
tm[1, 1] = cs
27-
tm[1, 3] = sn
28-
tm[2, 0] = -sn
29-
tm[2, 2] = cs
30-
tm[3, 1] = -sn
31-
tm[3, 3] = cs
23+
tm = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1)
24+
tm[..., 0, 0] = cs
25+
tm[..., 0, 2] = sn
26+
tm[..., 1, 1] = cs
27+
tm[..., 1, 3] = sn
28+
tm[..., 2, 0] = -sn
29+
tm[..., 2, 2] = cs
30+
tm[..., 3, 1] = -sn
31+
tm[..., 3, 3] = cs
3232

3333
return tm
3434

@@ -98,7 +98,9 @@ def base_rmatrix(
9898

9999
# Rotate the R matrix for skew / vertical magnets
100100
if torch.any(tilt != 0):
101-
R = torch.matmul(torch.matmul(rotation_matrix(-tilt), R), rotation_matrix(tilt))
101+
R = torch.einsum(
102+
"...ij,...jk,...kl->...il", rotation_matrix(-tilt), R, rotation_matrix(tilt)
103+
)
102104
return R
103105

104106

@@ -108,13 +110,14 @@ def misalignment_matrix(
108110
"""Shift the beam for tracking beam through misaligned elements"""
109111
device = misalignment.device
110112
dtype = misalignment.dtype
113+
batch_shape = misalignment.shape[:-1]
111114

112-
R_exit = torch.eye(7, device=device, dtype=dtype)
113-
R_exit[0, 6] = misalignment[0]
114-
R_exit[2, 6] = misalignment[1]
115+
R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1)
116+
R_exit[..., 0, 6] = misalignment[..., 0]
117+
R_exit[..., 2, 6] = misalignment[..., 1]
115118

116-
R_entry = torch.eye(7, device=device, dtype=dtype)
117-
R_entry[0, 6] = -misalignment[0]
118-
R_entry[2, 6] = -misalignment[1]
119+
R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1)
120+
R_entry[..., 0, 6] = -misalignment[..., 0]
121+
R_entry[..., 2, 6] = -misalignment[..., 1]
119122

120-
return R_exit, R_entry # TODO: This order is confusing, should be entry, exit
123+
return R_entry, R_exit

tests/test_compare_ocelot.py

+46
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,52 @@ def test_dipole_with_fringe_field():
101101
)
102102

103103

104+
def test_dipole_with_fringe_field_and_tilt():
105+
"""
106+
Test that the tracking results through a Cheeath `Dipole` element match those
107+
through an Oclet `Bend` element when there are fringe fields and tilt, and the
108+
e1 and e2 angles are set.
109+
"""
110+
# Cheetah
111+
bend_angle = np.pi / 6
112+
tilt_angle = np.pi / 4
113+
incoming_beam = cheetah.ParticleBeam.from_astra(
114+
"tests/resources/ACHIP_EA1_2021.1351.001"
115+
)
116+
cheetah_dipole = cheetah.Dipole(
117+
length=torch.tensor([1.0]),
118+
angle=torch.tensor([bend_angle]),
119+
fringe_integral=torch.tensor([0.1]),
120+
gap=torch.tensor([0.2]),
121+
tilt=torch.tensor([tilt_angle]),
122+
e1=torch.tensor([bend_angle / 2]),
123+
e2=torch.tensor([bend_angle / 2]),
124+
)
125+
outgoing_beam = cheetah_dipole(incoming_beam)
126+
127+
# Ocelot
128+
incoming_p_array = ocelot.astraBeam2particleArray(
129+
"tests/resources/ACHIP_EA1_2021.1351.001"
130+
)
131+
ocelot_bend = ocelot.Bend(
132+
l=1.0,
133+
angle=bend_angle,
134+
fint=0.1,
135+
gap=0.2,
136+
tilt=tilt_angle,
137+
e1=bend_angle / 2,
138+
e2=bend_angle / 2,
139+
)
140+
lattice = ocelot.MagneticLattice([ocelot_bend, ocelot.Marker()])
141+
navigator = ocelot.Navigator(lattice)
142+
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)
143+
144+
assert np.allclose(
145+
outgoing_beam.particles[0, :, :6].cpu().numpy(),
146+
outgoing_p_array.rparticles.transpose(),
147+
)
148+
149+
104150
def test_aperture():
105151
"""
106152
Test that the tracking results through a Cheeath `Aperture` element match those

0 commit comments

Comments
 (0)