@@ -295,9 +295,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
295
295
beta = torch .sqrt (1 - igamma2 )
296
296
297
297
tm = torch .eye (7 , device = device , dtype = dtype ).repeat ((* self .length .shape , 1 , 1 ))
298
- tm [: , 0 , 1 ] = self .length
299
- tm [: , 2 , 3 ] = self .length
300
- tm [: , 4 , 5 ] = - self .length / beta ** 2 * igamma2
298
+ tm [... , 0 , 1 ] = self .length
299
+ tm [... , 2 , 3 ] = self .length
300
+ tm [... , 4 , 5 ] = - self .length / beta ** 2 * igamma2
301
301
302
302
return tm
303
303
@@ -379,7 +379,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
379
379
energy = energy ,
380
380
)
381
381
382
- if all (self .misalignment [:, 0 ] == 0 ) and all (self .misalignment [:, 1 ] == 0 ):
382
+ if torch .all (self .misalignment [:, 0 ] == 0 ) and torch .all (
383
+ self .misalignment [:, 1 ] == 0
384
+ ):
383
385
return R
384
386
else :
385
387
R_exit , R_entry = misalignment_matrix (self .misalignment )
@@ -750,10 +752,10 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
750
752
beta = torch .sqrt (1 - igamma2 )
751
753
752
754
tm = torch .eye (7 , device = device , dtype = dtype ).repeat ((* self .length .shape , 1 , 1 ))
753
- tm [: , 0 , 1 ] = self .length
754
- tm [: , 1 , 6 ] = self .angle
755
- tm [: , 2 , 3 ] = self .length
756
- tm [: , 4 , 5 ] = - self .length / beta ** 2 * igamma2
755
+ tm [... , 0 , 1 ] = self .length
756
+ tm [... , 1 , 6 ] = self .angle
757
+ tm [... , 2 , 3 ] = self .length
758
+ tm [... , 4 , 5 ] = - self .length / beta ** 2 * igamma2
757
759
758
760
return tm
759
761
@@ -840,10 +842,10 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
840
842
beta = torch .sqrt (1 - igamma2 )
841
843
842
844
tm = torch .eye (7 , device = device , dtype = dtype ).repeat ((* self .length .shape , 1 , 1 ))
843
- tm [: , 0 , 1 ] = self .length
844
- tm [: , 2 , 3 ] = self .length
845
- tm [: , 3 , 6 ] = self .angle
846
- tm [: , 4 , 5 ] = - self .length / beta ** 2 * igamma2
845
+ tm [... , 0 , 1 ] = self .length
846
+ tm [... , 2 , 3 ] = self .length
847
+ tm [... , 3 , 6 ] = self .angle
848
+ tm [... , 4 , 5 ] = - self .length / beta ** 2 * igamma2
847
849
return tm
848
850
849
851
def broadcast (self , shape : Size ) -> Element :
@@ -940,29 +942,11 @@ def is_skippable(self) -> bool:
940
942
return not self .is_active
941
943
942
944
def transfer_map (self , energy : torch .Tensor ) -> torch .Tensor :
943
- device = self .length .device
944
- dtype = self .length .dtype
945
-
946
- # TODO: This feels weird because I'm computing the all transfer maps for both
947
- # cases, but only using one of them. Maybe there is a better way to do this.
948
- # ... or am I?
949
- tm = torch .empty ((* self .length .shape , 7 , 7 ), device = device , dtype = dtype )
950
- if any (self .voltage > 0 ):
951
- tm [self .voltage > 0 ] = self ._cavity_rmatrix (energy [self .voltage > 0 ])
952
- if any (self .voltage <= 0 ):
953
- tm [self .voltage <= 0 ] = base_rmatrix (
954
- length = self .length [self .voltage <= 0 ],
955
- k1 = torch .zeros_like (
956
- self .length [self .voltage <= 0 ], device = device , dtype = dtype
957
- ),
958
- hx = torch .zeros_like (
959
- self .length [self .voltage <= 0 ], device = device , dtype = dtype
960
- ),
961
- tilt = torch .zeros_like (
962
- self .length [self .voltage <= 0 ], device = device , dtype = dtype
963
- ),
964
- energy = energy [self .voltage <= 0 ],
965
- )
945
+ # There used to be a check for voltage > 0 here, where the cavity transfer map
946
+ # was only computed for the elements with voltage > 0 and a basermatrix was
947
+ # used otherwise. This was removed because it was causing issues with the
948
+ # vectorisation, but I am not sure it is okay to remove.
949
+ tm = self ._cavity_rmatrix (energy )
966
950
967
951
return tm
968
952
@@ -990,11 +974,12 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
990
974
igamma2 = torch .full_like (self .length , 0.0 )
991
975
g0 = torch .full_like (self .length , 1e10 )
992
976
993
- g0 [incoming .energy != 0 ] = incoming .energy / electron_mass_eV .to (
977
+ mask = incoming .energy != 0
978
+ g0 [mask ] = incoming .energy [mask ] / electron_mass_eV .to (
994
979
device = device , dtype = dtype
995
980
)
996
- igamma2 [incoming . energy != 0 ] = 1 / g0 [incoming . energy != 0 ] ** 2
997
- beta0 [incoming . energy != 0 ] = torch .sqrt (1 - igamma2 [incoming . energy != 0 ])
981
+ igamma2 [mask ] = 1 / g0 [mask ] ** 2
982
+ beta0 [mask ] = torch .sqrt (1 - igamma2 [mask ])
998
983
999
984
phi = torch .deg2rad (self .phase )
1000
985
@@ -1012,22 +997,22 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
1012
997
T556 = torch .full_like (self .length , 0.0 )
1013
998
T555 = torch .full_like (self .length , 0.0 )
1014
999
1015
- if any (incoming .energy + delta_energy > 0 ):
1000
+ if torch . any (incoming .energy + delta_energy > 0 ):
1016
1001
k = 2 * torch .pi * self .frequency / constants .speed_of_light
1017
1002
outgoing_energy = incoming .energy + delta_energy
1018
1003
g1 = outgoing_energy / electron_mass_eV
1019
1004
beta1 = torch .sqrt (1 - 1 / g1 ** 2 )
1020
1005
1021
1006
if isinstance (incoming , ParameterBeam ):
1022
- outgoing_mu [: , 5 ] = incoming ._mu [: , 5 ] * incoming .energy * beta0 / (
1007
+ outgoing_mu [... , 5 ] = incoming ._mu [... , 5 ] * incoming .energy * beta0 / (
1023
1008
outgoing_energy * beta1
1024
1009
) + self .voltage * beta0 / (outgoing_energy * beta1 ) * (
1025
- torch .cos (- incoming ._mu [: , 4 ] * beta0 * k + phi ) - torch .cos (phi )
1010
+ torch .cos (- incoming ._mu [... , 4 ] * beta0 * k + phi ) - torch .cos (phi )
1026
1011
)
1027
- outgoing_cov [: , 5 , 5 ] = incoming ._cov [: , 5 , 5 ]
1012
+ outgoing_cov [... , 5 , 5 ] = incoming ._cov [... , 5 , 5 ]
1028
1013
else : # ParticleBeam
1029
- outgoing_particles [:, : , 5 ] = incoming .particles [
1030
- :, : , 5
1014
+ outgoing_particles [... , 5 ] = incoming .particles [
1015
+ ... , 5
1031
1016
] * incoming .energy .unsqueeze (- 1 ) * beta0 .unsqueeze (- 1 ) / (
1032
1017
outgoing_energy .unsqueeze (- 1 ) * beta1 .unsqueeze (- 1 )
1033
1018
) + self .voltage .unsqueeze (
@@ -1038,7 +1023,7 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
1038
1023
outgoing_energy .unsqueeze (- 1 ) * beta1 .unsqueeze (- 1 )
1039
1024
) * (
1040
1025
torch .cos (
1041
- incoming .particles [:, : , 4 ]
1026
+ incoming .particles [... , 4 ]
1042
1027
* beta0 .unsqueeze (- 1 )
1043
1028
* k .unsqueeze (- 1 )
1044
1029
+ phi .unsqueeze (- 1 )
@@ -1047,7 +1032,7 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
1047
1032
)
1048
1033
1049
1034
dgamma = self .voltage / electron_mass_eV
1050
- if any (delta_energy > 0 ):
1035
+ if torch . any (delta_energy > 0 ):
1051
1036
T566 = (
1052
1037
self .length
1053
1038
* (beta0 ** 3 * g0 ** 3 - beta1 ** 3 * g1 ** 3 )
@@ -1086,29 +1071,29 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
1086
1071
)
1087
1072
1088
1073
if isinstance (incoming , ParameterBeam ):
1089
- outgoing_mu [: , 4 ] = outgoing_mu [: , 4 ] + (
1090
- T566 * incoming ._mu [: , 5 ] ** 2
1091
- + T556 * incoming ._mu [: , 4 ] * incoming ._mu [: , 5 ]
1092
- + T555 * incoming ._mu [: , 4 ] ** 2
1074
+ outgoing_mu [... , 4 ] = outgoing_mu [... , 4 ] + (
1075
+ T566 * incoming ._mu [... , 5 ] ** 2
1076
+ + T556 * incoming ._mu [... , 4 ] * incoming ._mu [... , 5 ]
1077
+ + T555 * incoming ._mu [... , 4 ] ** 2
1093
1078
)
1094
- outgoing_cov [: , 4 , 4 ] = (
1095
- T566 * incoming ._cov [: , 5 , 5 ] ** 2
1096
- + T556 * incoming ._cov [: , 4 , 5 ] * incoming ._cov [: , 5 , 5 ]
1097
- + T555 * incoming ._cov [: , 4 , 4 ] ** 2
1079
+ outgoing_cov [... , 4 , 4 ] = (
1080
+ T566 * incoming ._cov [... , 5 , 5 ] ** 2
1081
+ + T556 * incoming ._cov [... , 4 , 5 ] * incoming ._cov [... , 5 , 5 ]
1082
+ + T555 * incoming ._cov [... , 4 , 4 ] ** 2
1098
1083
)
1099
- outgoing_cov [: , 4 , 5 ] = (
1100
- T566 * incoming ._cov [: , 5 , 5 ] ** 2
1101
- + T556 * incoming ._cov [: , 4 , 5 ] * incoming ._cov [: , 5 , 5 ]
1102
- + T555 * incoming ._cov [: , 4 , 4 ] ** 2
1084
+ outgoing_cov [... , 4 , 5 ] = (
1085
+ T566 * incoming ._cov [... , 5 , 5 ] ** 2
1086
+ + T556 * incoming ._cov [... , 4 , 5 ] * incoming ._cov [... , 5 , 5 ]
1087
+ + T555 * incoming ._cov [... , 4 , 4 ] ** 2
1103
1088
)
1104
- outgoing_cov [: , 5 , 4 ] = outgoing_cov [: , 4 , 5 ]
1089
+ outgoing_cov [... , 5 , 4 ] = outgoing_cov [... , 4 , 5 ]
1105
1090
else : # ParticleBeam
1106
- outgoing_particles [:, :, 4 ] = outgoing_particles [:, : , 4 ] + (
1107
- T566 .unsqueeze (- 1 ) * incoming .particles [:, : , 5 ] ** 2
1091
+ outgoing_particles [..., 4 ] = outgoing_particles [... , 4 ] + (
1092
+ T566 .unsqueeze (- 1 ) * incoming .particles [... , 5 ] ** 2
1108
1093
+ T556 .unsqueeze (- 1 )
1109
- * incoming .particles [:, : , 4 ]
1110
- * incoming .particles [:, : , 5 ]
1111
- + T555 .unsqueeze (- 1 ) * incoming .particles [:, : , 4 ] ** 2
1094
+ * incoming .particles [... , 4 ]
1095
+ * incoming .particles [... , 5 ]
1096
+ + T555 .unsqueeze (- 1 ) * incoming .particles [... , 4 ] ** 2
1112
1097
)
1113
1098
1114
1099
if isinstance (incoming , ParameterBeam ):
@@ -1143,7 +1128,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
1143
1128
Ei = energy / electron_mass_eV
1144
1129
Ef = (energy + delta_energy ) / electron_mass_eV
1145
1130
Ep = (Ef - Ei ) / self .length # Derivative of the energy
1146
- assert all (Ei > 0 ), "Initial energy must be larger than 0"
1131
+ assert torch . all (Ei > 0 ), "Initial energy must be larger than 0"
1147
1132
1148
1133
alpha = torch .sqrt (eta / 8 ) / torch .cos (phi ) * torch .log (Ef / Ei )
1149
1134
@@ -1179,7 +1164,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
1179
1164
1180
1165
k = 2 * torch .pi * self .frequency / torch .tensor (constants .speed_of_light )
1181
1166
r55_cor = 0.0
1182
- if any ((self .voltage != 0 ) & (energy != 0 )): # TODO: Do we need this if?
1167
+ if torch . any ((self .voltage != 0 ) & (energy != 0 )): # TODO: Do we need this if?
1183
1168
beta0 = torch .sqrt (1 - 1 / Ei ** 2 )
1184
1169
beta1 = torch .sqrt (1 - 1 / Ef ** 2 )
1185
1170
@@ -1201,18 +1186,18 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
1201
1186
r65 = k * torch .sin (phi ) * self .voltage / (Ef * beta1 * electron_mass_eV )
1202
1187
1203
1188
R = torch .eye (7 , device = device , dtype = dtype ).repeat ((* self .length .shape , 1 , 1 ))
1204
- R [: , 0 , 0 ] = r11
1205
- R [: , 0 , 1 ] = r12
1206
- R [: , 1 , 0 ] = r21
1207
- R [: , 1 , 1 ] = r22
1208
- R [: , 2 , 2 ] = r11
1209
- R [: , 2 , 3 ] = r12
1210
- R [: , 3 , 2 ] = r21
1211
- R [: , 3 , 3 ] = r22
1212
- R [: , 4 , 4 ] = 1 + r55_cor
1213
- R [: , 4 , 5 ] = r56
1214
- R [: , 5 , 4 ] = r65
1215
- R [: , 5 , 5 ] = r66
1189
+ R [... , 0 , 0 ] = r11
1190
+ R [... , 0 , 1 ] = r12
1191
+ R [... , 1 , 0 ] = r21
1192
+ R [... , 1 , 1 ] = r22
1193
+ R [... , 2 , 2 ] = r11
1194
+ R [... , 2 , 3 ] = r12
1195
+ R [... , 3 , 2 ] = r21
1196
+ R [... , 3 , 3 ] = r22
1197
+ R [... , 4 , 4 ] = 1 + r55_cor
1198
+ R [... , 4 , 5 ] = r56
1199
+ R [... , 5 , 4 ] = r65
1200
+ R [... , 5 , 5 ] = r66
1216
1201
1217
1202
return R
1218
1203
0 commit comments