forked from desy-ml/cheetah
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquadrupole.py
124 lines (106 loc) · 3.96 KB
/
quadrupole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from torch import Size, nn
from cheetah.track_methods import base_rmatrix, misalignment_matrix
from cheetah.utils import UniqueNameGenerator
from .element import Element
# Module-level UniqueNameGenerator with prefix "unnamed_element"; presumably
# used to name elements created without an explicit `name`. The Quadrupole
# class below does not reference it directly — verify other users before
# changing or removing it.
generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")
class Quadrupole(Element):
    """
    Quadrupole magnet in a particle accelerator.

    :param length: Length in meters.
    :param k1: Normalised strength of the quadrupole (conventionally given in
        1/m^2; an earlier version of this docstring said rad/m — confirm
        against `base_rmatrix`). Defaults to zero, which makes the element
        inactive (see `is_active`).
    :param misalignment: Misalignment vector of the quadrupole in x- and
        y-directions. Defaults to zeros of shape ``(*length.shape, 2)``.
    :param tilt: Tilt angle of the quadrupole in x-y plane [rad]. pi/4 for
        skew-quadrupole. Defaults to zero.
    :param name: Unique identifier of the element.
    :param device: Device on which the parameter tensors are created.
    :param dtype: Data type of the parameter tensors.
    """

    def __init__(
        self,
        length: Union[torch.Tensor, nn.Parameter],
        k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
        misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
        tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
        name: Optional[str] = None,
        device=None,
        dtype=torch.float32,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(name=name)

        self.length = torch.as_tensor(length, **factory_kwargs)
        # Defaults below derive from `self.length`, so they share its
        # (already converted) shape, dtype and device.
        self.k1 = (
            torch.as_tensor(k1, **factory_kwargs)
            if k1 is not None
            else torch.zeros_like(self.length)
        )
        self.misalignment = (
            torch.as_tensor(misalignment, **factory_kwargs)
            if misalignment is not None
            else torch.zeros((*self.length.shape, 2), **factory_kwargs)
        )
        self.tilt = (
            torch.as_tensor(tilt, **factory_kwargs)
            if tilt is not None
            else torch.zeros_like(self.length)
        )

    def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
        """
        Return the transfer matrix of the quadrupole.

        :param energy: Energy forwarded to `base_rmatrix`.
        :return: The matrix from `base_rmatrix`, sandwiched between the
            misalignment entry/exit matrices when any misalignment component
            is non-zero.
        """
        R = base_rmatrix(
            length=self.length,
            k1=self.k1,
            # hx is zero here (assumed to be the bending-curvature term of
            # base_rmatrix, which a quadrupole does not have).
            hx=torch.zeros_like(self.length),
            tilt=self.tilt,
            energy=energy,
        )

        if torch.all(self.misalignment == 0):
            # Aligned magnet: skip the extra matrix products.
            return R

        R_entry, R_exit = misalignment_matrix(self.misalignment)
        # Batched product R_exit @ R @ R_entry over any leading dimensions.
        return torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)

    def broadcast(self, shape: Size) -> Element:
        """
        Return a copy of this quadrupole with every parameter repeated to `shape`.

        The dtype and device of the existing parameters are preserved (without
        passing `dtype`, `__init__` would cast everything to float32).
        """
        return self.__class__(
            length=self.length.repeat(shape),
            k1=self.k1.repeat(shape),
            misalignment=self.misalignment.repeat((*shape, 1)),
            tilt=self.tilt.repeat(shape),
            name=self.name,
            device=self.length.device,
            dtype=self.length.dtype,
        )

    @property
    def is_skippable(self) -> bool:
        return True

    @property
    def is_active(self) -> bool:
        """True if any entry of `k1` is non-zero (works for tensors of any rank)."""
        return bool(torch.any(self.k1 != 0))

    def split(self, resolution: torch.Tensor) -> list[Element]:
        """
        Split the quadrupole into consecutive pieces of at most `resolution` length.

        The final piece takes whatever length remains. Every piece keeps this
        quadrupole's `k1`, `misalignment` and `tilt`, as well as its dtype and
        device. `self` is not modified.

        :param resolution: Maximum length of each piece in meters.
        :return: List of quadrupoles whose lengths add up to `self.length`.
        """
        split_elements = []
        remaining = self.length
        while remaining > 0:
            element = Quadrupole(
                torch.min(resolution, remaining),
                self.k1,
                misalignment=self.misalignment,
                tilt=self.tilt,
                device=self.length.device,
                dtype=self.length.dtype,
            )
            split_elements.append(element)
            # Out-of-place on purpose: `remaining` initially aliases
            # `self.length`, so `remaining -= resolution` would shrink the
            # original element's length in place.
            remaining = remaining - resolution
        return split_elements

    def plot(self, ax: plt.Axes, s: float) -> None:
        """
        Draw the quadrupole on `ax` as a red rectangle starting at position `s`.

        The rectangle points up or down according to the sign of the first
        `k1` entry and is faded when the quadrupole is inactive.
        """
        # Convert to Python floats first: numpy conversion of a tensor fails
        # when it requires grad or lives on a GPU, and `[0]` fails on 0-d
        # tensors, whereas flatten()[0] handles 0-d and 1-d alike.
        k1_first = float(self.k1.flatten()[0])
        length_first = float(self.length.flatten()[0])

        alpha = 1 if self.is_active else 0.2
        height = 0.8 * (np.sign(k1_first) if self.is_active else 1)
        patch = Rectangle(
            (s, 0), length_first, height, color="tab:red", alpha=alpha, zorder=2
        )
        ax.add_patch(patch)

    @property
    def defining_features(self) -> list[str]:
        return super().defining_features + ["length", "k1", "misalignment", "tilt"]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(length={repr(self.length)}, "
            + f"k1={repr(self.k1)}, "
            + f"misalignment={repr(self.misalignment)}, "
            + f"tilt={repr(self.tilt)}, "
            + f"name={repr(self.name)})"
        )