Skip to content

Commit 47a5497

Browse files
committed
Test that n-dimensional inputs work
1 parent dc4e5d7 commit 47a5497

File tree

4 files changed

+373
-207
lines changed

4 files changed

+373
-207
lines changed

cheetah/accelerator.py

+63-78
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
295295
beta = torch.sqrt(1 - igamma2)
296296

297297
tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
298-
tm[:, 0, 1] = self.length
299-
tm[:, 2, 3] = self.length
300-
tm[:, 4, 5] = -self.length / beta**2 * igamma2
298+
tm[..., 0, 1] = self.length
299+
tm[..., 2, 3] = self.length
300+
tm[..., 4, 5] = -self.length / beta**2 * igamma2
301301

302302
return tm
303303

@@ -379,7 +379,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
379379
energy=energy,
380380
)
381381

382-
if all(self.misalignment[:, 0] == 0) and all(self.misalignment[:, 1] == 0):
382+
if torch.all(self.misalignment[:, 0] == 0) and torch.all(
383+
self.misalignment[:, 1] == 0
384+
):
383385
return R
384386
else:
385387
R_exit, R_entry = misalignment_matrix(self.misalignment)
@@ -750,10 +752,10 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
750752
beta = torch.sqrt(1 - igamma2)
751753

752754
tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
753-
tm[:, 0, 1] = self.length
754-
tm[:, 1, 6] = self.angle
755-
tm[:, 2, 3] = self.length
756-
tm[:, 4, 5] = -self.length / beta**2 * igamma2
755+
tm[..., 0, 1] = self.length
756+
tm[..., 1, 6] = self.angle
757+
tm[..., 2, 3] = self.length
758+
tm[..., 4, 5] = -self.length / beta**2 * igamma2
757759

758760
return tm
759761

@@ -840,10 +842,10 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
840842
beta = torch.sqrt(1 - igamma2)
841843

842844
tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
843-
tm[:, 0, 1] = self.length
844-
tm[:, 2, 3] = self.length
845-
tm[:, 3, 6] = self.angle
846-
tm[:, 4, 5] = -self.length / beta**2 * igamma2
845+
tm[..., 0, 1] = self.length
846+
tm[..., 2, 3] = self.length
847+
tm[..., 3, 6] = self.angle
848+
tm[..., 4, 5] = -self.length / beta**2 * igamma2
847849
return tm
848850

849851
def broadcast(self, shape: Size) -> Element:
@@ -940,29 +942,11 @@ def is_skippable(self) -> bool:
940942
return not self.is_active
941943

942944
def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
943-
device = self.length.device
944-
dtype = self.length.dtype
945-
946-
# TODO: This feels weird because I'm computing the all transfer maps for both
947-
# cases, but only using one of them. Maybe there is a better way to do this.
948-
# ... or am I?
949-
tm = torch.empty((*self.length.shape, 7, 7), device=device, dtype=dtype)
950-
if any(self.voltage > 0):
951-
tm[self.voltage > 0] = self._cavity_rmatrix(energy[self.voltage > 0])
952-
if any(self.voltage <= 0):
953-
tm[self.voltage <= 0] = base_rmatrix(
954-
length=self.length[self.voltage <= 0],
955-
k1=torch.zeros_like(
956-
self.length[self.voltage <= 0], device=device, dtype=dtype
957-
),
958-
hx=torch.zeros_like(
959-
self.length[self.voltage <= 0], device=device, dtype=dtype
960-
),
961-
tilt=torch.zeros_like(
962-
self.length[self.voltage <= 0], device=device, dtype=dtype
963-
),
964-
energy=energy[self.voltage <= 0],
965-
)
945+
# There used to be a check for voltage > 0 here, where the cavity transfer map
946+
# was only computed for the elements with voltage > 0 and a basermatrix was
947+
# used otherwise. This was removed because it was causing issues with the
948+
# vectorisation, but I am not sure it is okay to remove.
949+
tm = self._cavity_rmatrix(energy)
966950

967951
return tm
968952

@@ -990,11 +974,12 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
990974
igamma2 = torch.full_like(self.length, 0.0)
991975
g0 = torch.full_like(self.length, 1e10)
992976

993-
g0[incoming.energy != 0] = incoming.energy / electron_mass_eV.to(
977+
mask = incoming.energy != 0
978+
g0[mask] = incoming.energy[mask] / electron_mass_eV.to(
994979
device=device, dtype=dtype
995980
)
996-
igamma2[incoming.energy != 0] = 1 / g0[incoming.energy != 0] ** 2
997-
beta0[incoming.energy != 0] = torch.sqrt(1 - igamma2[incoming.energy != 0])
981+
igamma2[mask] = 1 / g0[mask] ** 2
982+
beta0[mask] = torch.sqrt(1 - igamma2[mask])
998983

999984
phi = torch.deg2rad(self.phase)
1000985

@@ -1012,22 +997,22 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
1012997
T556 = torch.full_like(self.length, 0.0)
1013998
T555 = torch.full_like(self.length, 0.0)
1014999

1015-
if any(incoming.energy + delta_energy > 0):
1000+
if torch.any(incoming.energy + delta_energy > 0):
10161001
k = 2 * torch.pi * self.frequency / constants.speed_of_light
10171002
outgoing_energy = incoming.energy + delta_energy
10181003
g1 = outgoing_energy / electron_mass_eV
10191004
beta1 = torch.sqrt(1 - 1 / g1**2)
10201005

10211006
if isinstance(incoming, ParameterBeam):
1022-
outgoing_mu[:, 5] = incoming._mu[:, 5] * incoming.energy * beta0 / (
1007+
outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / (
10231008
outgoing_energy * beta1
10241009
) + self.voltage * beta0 / (outgoing_energy * beta1) * (
1025-
torch.cos(-incoming._mu[:, 4] * beta0 * k + phi) - torch.cos(phi)
1010+
torch.cos(-incoming._mu[..., 4] * beta0 * k + phi) - torch.cos(phi)
10261011
)
1027-
outgoing_cov[:, 5, 5] = incoming._cov[:, 5, 5]
1012+
outgoing_cov[..., 5, 5] = incoming._cov[..., 5, 5]
10281013
else: # ParticleBeam
1029-
outgoing_particles[:, :, 5] = incoming.particles[
1030-
:, :, 5
1014+
outgoing_particles[..., 5] = incoming.particles[
1015+
..., 5
10311016
] * incoming.energy.unsqueeze(-1) * beta0.unsqueeze(-1) / (
10321017
outgoing_energy.unsqueeze(-1) * beta1.unsqueeze(-1)
10331018
) + self.voltage.unsqueeze(
@@ -1038,7 +1023,7 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
10381023
outgoing_energy.unsqueeze(-1) * beta1.unsqueeze(-1)
10391024
) * (
10401025
torch.cos(
1041-
incoming.particles[:, :, 4]
1026+
incoming.particles[..., 4]
10421027
* beta0.unsqueeze(-1)
10431028
* k.unsqueeze(-1)
10441029
+ phi.unsqueeze(-1)
@@ -1047,7 +1032,7 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
10471032
)
10481033

10491034
dgamma = self.voltage / electron_mass_eV
1050-
if any(delta_energy > 0):
1035+
if torch.any(delta_energy > 0):
10511036
T566 = (
10521037
self.length
10531038
* (beta0**3 * g0**3 - beta1**3 * g1**3)
@@ -1086,29 +1071,29 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
10861071
)
10871072

10881073
if isinstance(incoming, ParameterBeam):
1089-
outgoing_mu[:, 4] = outgoing_mu[:, 4] + (
1090-
T566 * incoming._mu[:, 5] ** 2
1091-
+ T556 * incoming._mu[:, 4] * incoming._mu[:, 5]
1092-
+ T555 * incoming._mu[:, 4] ** 2
1074+
outgoing_mu[..., 4] = outgoing_mu[..., 4] + (
1075+
T566 * incoming._mu[..., 5] ** 2
1076+
+ T556 * incoming._mu[..., 4] * incoming._mu[..., 5]
1077+
+ T555 * incoming._mu[..., 4] ** 2
10931078
)
1094-
outgoing_cov[:, 4, 4] = (
1095-
T566 * incoming._cov[:, 5, 5] ** 2
1096-
+ T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5]
1097-
+ T555 * incoming._cov[:, 4, 4] ** 2
1079+
outgoing_cov[..., 4, 4] = (
1080+
T566 * incoming._cov[..., 5, 5] ** 2
1081+
+ T556 * incoming._cov[..., 4, 5] * incoming._cov[..., 5, 5]
1082+
+ T555 * incoming._cov[..., 4, 4] ** 2
10981083
)
1099-
outgoing_cov[:, 4, 5] = (
1100-
T566 * incoming._cov[:, 5, 5] ** 2
1101-
+ T556 * incoming._cov[:, 4, 5] * incoming._cov[:, 5, 5]
1102-
+ T555 * incoming._cov[:, 4, 4] ** 2
1084+
outgoing_cov[..., 4, 5] = (
1085+
T566 * incoming._cov[..., 5, 5] ** 2
1086+
+ T556 * incoming._cov[..., 4, 5] * incoming._cov[..., 5, 5]
1087+
+ T555 * incoming._cov[..., 4, 4] ** 2
11031088
)
1104-
outgoing_cov[:, 5, 4] = outgoing_cov[:, 4, 5]
1089+
outgoing_cov[..., 5, 4] = outgoing_cov[..., 4, 5]
11051090
else: # ParticleBeam
1106-
outgoing_particles[:, :, 4] = outgoing_particles[:, :, 4] + (
1107-
T566.unsqueeze(-1) * incoming.particles[:, :, 5] ** 2
1091+
outgoing_particles[..., 4] = outgoing_particles[..., 4] + (
1092+
T566.unsqueeze(-1) * incoming.particles[..., 5] ** 2
11081093
+ T556.unsqueeze(-1)
1109-
* incoming.particles[:, :, 4]
1110-
* incoming.particles[:, :, 5]
1111-
+ T555.unsqueeze(-1) * incoming.particles[:, :, 4] ** 2
1094+
* incoming.particles[..., 4]
1095+
* incoming.particles[..., 5]
1096+
+ T555.unsqueeze(-1) * incoming.particles[..., 4] ** 2
11121097
)
11131098

11141099
if isinstance(incoming, ParameterBeam):
@@ -1143,7 +1128,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
11431128
Ei = energy / electron_mass_eV
11441129
Ef = (energy + delta_energy) / electron_mass_eV
11451130
Ep = (Ef - Ei) / self.length # Derivative of the energy
1146-
assert all(Ei > 0), "Initial energy must be larger than 0"
1131+
assert torch.all(Ei > 0), "Initial energy must be larger than 0"
11471132

11481133
alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei)
11491134

@@ -1179,7 +1164,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
11791164

11801165
k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
11811166
r55_cor = 0.0
1182-
if any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
1167+
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
11831168
beta0 = torch.sqrt(1 - 1 / Ei**2)
11841169
beta1 = torch.sqrt(1 - 1 / Ef**2)
11851170

@@ -1201,18 +1186,18 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
12011186
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)
12021187

12031188
R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
1204-
R[:, 0, 0] = r11
1205-
R[:, 0, 1] = r12
1206-
R[:, 1, 0] = r21
1207-
R[:, 1, 1] = r22
1208-
R[:, 2, 2] = r11
1209-
R[:, 2, 3] = r12
1210-
R[:, 3, 2] = r21
1211-
R[:, 3, 3] = r22
1212-
R[:, 4, 4] = 1 + r55_cor
1213-
R[:, 4, 5] = r56
1214-
R[:, 5, 4] = r65
1215-
R[:, 5, 5] = r66
1189+
R[..., 0, 0] = r11
1190+
R[..., 0, 1] = r12
1191+
R[..., 1, 0] = r21
1192+
R[..., 1, 1] = r22
1193+
R[..., 2, 2] = r11
1194+
R[..., 2, 3] = r12
1195+
R[..., 3, 2] = r21
1196+
R[..., 3, 3] = r22
1197+
R[..., 4, 4] = 1 + r55_cor
1198+
R[..., 4, 5] = r56
1199+
R[..., 5, 4] = r65
1200+
R[..., 5, 5] = r66
12161201

12171202
return R
12181203

0 commit comments

Comments
 (0)