@@ -379,13 +379,11 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
379
379
energy = energy ,
380
380
)
381
381
382
- if torch .all (self .misalignment [:, 0 ] == 0 ) and torch .all (
383
- self .misalignment [:, 1 ] == 0
384
- ):
382
+ if torch .all (self .misalignment == 0 ):
385
383
return R
386
384
else :
387
- R_exit , R_entry = misalignment_matrix (self .misalignment )
388
- R = torch .matmul ( R_exit , torch . matmul ( R , R_entry ) )
385
+ R_entry , R_exit = misalignment_matrix (self .misalignment )
386
+ R = torch .einsum ( "...ij,...jk,...kl->...il" , R_exit , R , R_entry )
389
387
return R
390
388
391
389
def broadcast (self , shape : Size ) -> Element :
@@ -542,23 +540,21 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
542
540
hx = self .hx ,
543
541
tilt = torch .zeros_like (self .length ),
544
542
energy = energy ,
545
- )
543
+ ) # Tilt is applied after adding edges
546
544
else : # Reduce to Thin-Corrector
547
545
R = torch .eye (7 , device = device , dtype = dtype ).repeat (
548
546
(* self .length .shape , 1 , 1 )
549
547
)
550
- R [: , 0 , 1 ] = self .length
551
- R [: , 2 , 6 ] = self .angle
552
- R [: , 2 , 3 ] = self .length
548
+ R [... , 0 , 1 ] = self .length
549
+ R [... , 2 , 6 ] = self .angle
550
+ R [... , 2 , 3 ] = self .length
553
551
554
552
# Apply fringe fields
555
553
R = torch .matmul (R_exit , torch .matmul (R , R_enter ))
556
554
# Apply rotation for tilted magnets
557
- # TODO: Are we applying tilt twice (here and base_rmatrix)?
558
555
R = torch .matmul (
559
556
rotation_matrix (- self .tilt ), torch .matmul (R , rotation_matrix (self .tilt ))
560
557
)
561
-
562
558
return R
563
559
564
560
def _transfer_map_enter (self ) -> torch .Tensor :
@@ -576,8 +572,8 @@ def _transfer_map_enter(self) -> torch.Tensor:
576
572
)
577
573
578
574
tm = torch .eye (7 , device = device , dtype = dtype ).repeat (* phi .shape , 1 , 1 )
579
- tm [: , 1 , 0 ] = self .hx * torch .tan (self .e1 )
580
- tm [: , 3 , 2 ] = - self .hx * torch .tan (self .e1 - phi )
575
+ tm [... , 1 , 0 ] = self .hx * torch .tan (self .e1 )
576
+ tm [... , 3 , 2 ] = - self .hx * torch .tan (self .e1 - phi )
581
577
582
578
return tm
583
579
@@ -596,8 +592,8 @@ def _transfer_map_exit(self) -> torch.Tensor:
596
592
)
597
593
598
594
tm = torch .eye (7 , device = device , dtype = dtype ).repeat (* phi .shape , 1 , 1 )
599
- tm [: , 1 , 0 ] = self .hx * torch .tan (self .e2 )
600
- tm [: , 3 , 2 ] = - self .hx * torch .tan (self .e2 - phi )
595
+ tm [... , 1 , 0 ] = self .hx * torch .tan (self .e2 )
596
+ tm [... , 3 , 2 ] = - self .hx * torch .tan (self .e2 - phi )
601
597
602
598
return tm
603
599
@@ -1448,11 +1444,11 @@ def track(self, incoming: Beam) -> Beam:
1448
1444
copy_of_incoming = deepcopy (incoming )
1449
1445
1450
1446
if isinstance (incoming , ParameterBeam ):
1451
- copy_of_incoming ._mu [: , 0 ] -= self .misalignment [: , 0 ]
1452
- copy_of_incoming ._mu [: , 2 ] -= self .misalignment [: , 1 ]
1447
+ copy_of_incoming ._mu [... , 0 ] -= self .misalignment [... , 0 ]
1448
+ copy_of_incoming ._mu [... , 2 ] -= self .misalignment [... , 1 ]
1453
1449
elif isinstance (incoming , ParticleBeam ):
1454
- copy_of_incoming .particles [: , :, 0 ] -= self .misalignment [: , 0 ]
1455
- copy_of_incoming .particles [: , :, 1 ] -= self .misalignment [: , 1 ]
1450
+ copy_of_incoming .particles [... , :, 0 ] -= self .misalignment [... , 0 ]
1451
+ copy_of_incoming .particles [... , :, 1 ] -= self .misalignment [... , 1 ]
1456
1452
1457
1453
self .set_read_beam (copy_of_incoming )
1458
1454
@@ -1476,18 +1472,18 @@ def reading(self) -> torch.Tensor:
1476
1472
)
1477
1473
elif isinstance (read_beam , ParameterBeam ):
1478
1474
transverse_mu = torch .stack (
1479
- [read_beam ._mu [: , 0 ], read_beam ._mu [: , 2 ]], dim = 1
1475
+ [read_beam ._mu [... , 0 ], read_beam ._mu [... , 2 ]], dim = - 1
1480
1476
)
1481
1477
transverse_cov = torch .stack (
1482
1478
[
1483
1479
torch .stack (
1484
- [read_beam ._cov [: , 0 , 0 ], read_beam ._cov [: , 0 , 2 ]], dim = 1
1480
+ [read_beam ._cov [... , 0 , 0 ], read_beam ._cov [... , 0 , 2 ]], dim = - 1
1485
1481
),
1486
1482
torch .stack (
1487
- [read_beam ._cov [: , 2 , 0 ], read_beam ._cov [: , 2 , 2 ]], dim = 1
1483
+ [read_beam ._cov [... , 2 , 0 ], read_beam ._cov [... , 2 , 2 ]], dim = - 1
1488
1484
),
1489
1485
],
1490
- dim = 1 ,
1486
+ dim = - 1 ,
1491
1487
)
1492
1488
dist = [
1493
1489
MultivariateNormal (
@@ -1767,9 +1763,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
1767
1763
)
1768
1764
1769
1765
tm = torch .eye (7 , device = device , dtype = dtype ).repeat ((* energy .shape , 1 , 1 ))
1770
- tm [: , 0 , 1 ] = self .length
1771
- tm [: , 2 , 3 ] = self .length
1772
- tm [: , 4 , 5 ] = self .length * igamma2
1766
+ tm [... , 0 , 1 ] = self .length
1767
+ tm [... , 2 , 3 ] = self .length
1768
+ tm [... , 4 , 5 ] = self .length * igamma2
1773
1769
1774
1770
return tm
1775
1771
@@ -1866,38 +1862,38 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
1866
1862
r56 -= self .length / (beta * beta * gamma2 )
1867
1863
1868
1864
R = torch .eye (7 , device = device , dtype = dtype ).repeat ((* self .length .shape , 1 , 1 ))
1869
- R [: , 0 , 0 ] = c ** 2
1870
- R [: , 0 , 1 ] = c * s_k
1871
- R [: , 0 , 2 ] = s * c
1872
- R [: , 0 , 3 ] = s * s_k
1873
- R [: , 1 , 0 ] = - self .k * s * c
1874
- R [: , 1 , 1 ] = c ** 2
1875
- R [: , 1 , 2 ] = - self .k * s ** 2
1876
- R [: , 1 , 3 ] = s * c
1877
- R [: , 2 , 0 ] = - s * c
1878
- R [: , 2 , 1 ] = - s * s_k
1879
- R [: , 2 , 2 ] = c ** 2
1880
- R [: , 2 , 3 ] = c * s_k
1881
- R [: , 3 , 0 ] = self .k * s ** 2
1882
- R [: , 3 , 1 ] = - s * c
1883
- R [: , 3 , 2 ] = - self .k * s * c
1884
- R [: , 3 , 3 ] = c ** 2
1885
- R [: , 4 , 5 ] = r56
1865
+ R [... , 0 , 0 ] = c ** 2
1866
+ R [... , 0 , 1 ] = c * s_k
1867
+ R [... , 0 , 2 ] = s * c
1868
+ R [... , 0 , 3 ] = s * s_k
1869
+ R [... , 1 , 0 ] = - self .k * s * c
1870
+ R [... , 1 , 1 ] = c ** 2
1871
+ R [... , 1 , 2 ] = - self .k * s ** 2
1872
+ R [... , 1 , 3 ] = s * c
1873
+ R [... , 2 , 0 ] = - s * c
1874
+ R [... , 2 , 1 ] = - s * s_k
1875
+ R [... , 2 , 2 ] = c ** 2
1876
+ R [... , 2 , 3 ] = c * s_k
1877
+ R [... , 3 , 0 ] = self .k * s ** 2
1878
+ R [... , 3 , 1 ] = - s * c
1879
+ R [... , 3 , 2 ] = - self .k * s * c
1880
+ R [... , 3 , 3 ] = c ** 2
1881
+ R [... , 4 , 5 ] = r56
1886
1882
1887
1883
R = R .real
1888
1884
1889
- if self . misalignment [ 0 ] == 0 and self .misalignment [ 1 ] == 0 :
1885
+ if torch . all ( self .misalignment == 0 ) :
1890
1886
return R
1891
1887
else :
1892
- R_exit , R_entry = misalignment_matrix (self .misalignment )
1893
- R = torch .matmul ( R_exit , torch . matmul ( R , R_entry ) )
1888
+ R_entry , R_exit = misalignment_matrix (self .misalignment )
1889
+ R = torch .einsum ( "...ij,...jk,...kl->...il" , R_exit , R , R_entry )
1894
1890
return R
1895
1891
1896
1892
def broadcast (self , shape : Size ) -> Element :
1897
1893
return self .__class__ (
1898
1894
length = self .length .repeat (shape ),
1899
1895
k = self .k .repeat (shape ),
1900
- misalignment = self .misalignment .repeat (shape ),
1896
+ misalignment = self .misalignment .repeat (( * shape , 1 ) ),
1901
1897
name = self .name ,
1902
1898
)
1903
1899
0 commit comments