Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the batch execution for elements with tilts or misalignments #157

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

- Now all `Element` have a default length of `torch.zeros((1))`, fixing occasional issues with using elements without length, such as `Marker`, `BPM`, `Screen`, and `Aperture`. (see #143) (@cr-xu)
- Fix bug in `Cavity` `_track_beam` (see [#150](https://github.com/desy-ml/cheetah/issues/150)) (@jp-ga)
- Fix bugs when tracking elements with `tilt` or `misalignment`. (see #157) (@cr-xu)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cr-xu Did you fix anything other than the vectorisation? I guess because the changelog is with respect to the previous release (and not master), rather than adding a bug fix, we should add the issue and your name to the breaking vectorisation change entry. Does that make sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that's only with the vectorisation. Yes it makes sense to move it in the vectorisation entry.


### 🐆 Other

Expand Down
90 changes: 43 additions & 47 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,11 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
energy=energy,
)

if torch.all(self.misalignment[:, 0] == 0) and torch.all(
self.misalignment[:, 1] == 0
):
if torch.all(self.misalignment == 0):
return R
else:
R_exit, R_entry = misalignment_matrix(self.misalignment)
R = torch.matmul(R_exit, torch.matmul(R, R_entry))
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R

def broadcast(self, shape: Size) -> Element:
Expand Down Expand Up @@ -542,23 +540,21 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
hx=self.hx,
tilt=torch.zeros_like(self.length),
energy=energy,
)
) # Tilt is applied after adding edges
else: # Reduce to Thin-Corrector
R = torch.eye(7, device=device, dtype=dtype).repeat(
(*self.length.shape, 1, 1)
)
R[:, 0, 1] = self.length
R[:, 2, 6] = self.angle
R[:, 2, 3] = self.length
R[..., 0, 1] = self.length
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length

# Apply fringe fields
R = torch.matmul(R_exit, torch.matmul(R, R_enter))
# Apply rotation for tilted magnets
# TODO: Are we applying tilt twice (here and base_rmatrix)?
R = torch.matmul(
rotation_matrix(-self.tilt), torch.matmul(R, rotation_matrix(self.tilt))
)

return R

def _transfer_map_enter(self) -> torch.Tensor:
Expand All @@ -576,8 +572,8 @@ def _transfer_map_enter(self) -> torch.Tensor:
)

tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
tm[:, 1, 0] = self.hx * torch.tan(self.e1)
tm[:, 3, 2] = -self.hx * torch.tan(self.e1 - phi)
tm[..., 1, 0] = self.hx * torch.tan(self.e1)
tm[..., 3, 2] = -self.hx * torch.tan(self.e1 - phi)

return tm

Expand All @@ -596,8 +592,8 @@ def _transfer_map_exit(self) -> torch.Tensor:
)

tm = torch.eye(7, device=device, dtype=dtype).repeat(*phi.shape, 1, 1)
tm[:, 1, 0] = self.hx * torch.tan(self.e2)
tm[:, 3, 2] = -self.hx * torch.tan(self.e2 - phi)
tm[..., 1, 0] = self.hx * torch.tan(self.e2)
tm[..., 3, 2] = -self.hx * torch.tan(self.e2 - phi)

return tm

Expand Down Expand Up @@ -1448,11 +1444,11 @@ def track(self, incoming: Beam) -> Beam:
copy_of_incoming = deepcopy(incoming)

if isinstance(incoming, ParameterBeam):
copy_of_incoming._mu[:, 0] -= self.misalignment[:, 0]
copy_of_incoming._mu[:, 2] -= self.misalignment[:, 1]
copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0]
copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1]
elif isinstance(incoming, ParticleBeam):
copy_of_incoming.particles[:, :, 0] -= self.misalignment[:, 0]
copy_of_incoming.particles[:, :, 1] -= self.misalignment[:, 1]
copy_of_incoming.particles[..., :, 0] -= self.misalignment[..., 0]
copy_of_incoming.particles[..., :, 1] -= self.misalignment[..., 1]

self.set_read_beam(copy_of_incoming)

Expand All @@ -1476,18 +1472,18 @@ def reading(self) -> torch.Tensor:
)
elif isinstance(read_beam, ParameterBeam):
transverse_mu = torch.stack(
[read_beam._mu[:, 0], read_beam._mu[:, 2]], dim=1
[read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1
)
transverse_cov = torch.stack(
[
torch.stack(
[read_beam._cov[:, 0, 0], read_beam._cov[:, 0, 2]], dim=1
[read_beam._cov[..., 0, 0], read_beam._cov[..., 0, 2]], dim=-1
),
torch.stack(
[read_beam._cov[:, 2, 0], read_beam._cov[:, 2, 2]], dim=1
[read_beam._cov[..., 2, 0], read_beam._cov[..., 2, 2]], dim=-1
),
],
dim=1,
dim=-1,
)
dist = [
MultivariateNormal(
Expand Down Expand Up @@ -1767,9 +1763,9 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
)

tm = torch.eye(7, device=device, dtype=dtype).repeat((*energy.shape, 1, 1))
tm[:, 0, 1] = self.length
tm[:, 2, 3] = self.length
tm[:, 4, 5] = self.length * igamma2
tm[..., 0, 1] = self.length
tm[..., 2, 3] = self.length
tm[..., 4, 5] = self.length * igamma2

return tm

Expand Down Expand Up @@ -1866,38 +1862,38 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
r56 -= self.length / (beta * beta * gamma2)

R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
R[:, 0, 0] = c**2
R[:, 0, 1] = c * s_k
R[:, 0, 2] = s * c
R[:, 0, 3] = s * s_k
R[:, 1, 0] = -self.k * s * c
R[:, 1, 1] = c**2
R[:, 1, 2] = -self.k * s**2
R[:, 1, 3] = s * c
R[:, 2, 0] = -s * c
R[:, 2, 1] = -s * s_k
R[:, 2, 2] = c**2
R[:, 2, 3] = c * s_k
R[:, 3, 0] = self.k * s**2
R[:, 3, 1] = -s * c
R[:, 3, 2] = -self.k * s * c
R[:, 3, 3] = c**2
R[:, 4, 5] = r56
R[..., 0, 0] = c**2
R[..., 0, 1] = c * s_k
R[..., 0, 2] = s * c
R[..., 0, 3] = s * s_k
R[..., 1, 0] = -self.k * s * c
R[..., 1, 1] = c**2
R[..., 1, 2] = -self.k * s**2
R[..., 1, 3] = s * c
R[..., 2, 0] = -s * c
R[..., 2, 1] = -s * s_k
R[..., 2, 2] = c**2
R[..., 2, 3] = c * s_k
R[..., 3, 0] = self.k * s**2
R[..., 3, 1] = -s * c
R[..., 3, 2] = -self.k * s * c
R[..., 3, 3] = c**2
R[..., 4, 5] = r56

R = R.real

if self.misalignment[0] == 0 and self.misalignment[1] == 0:
if torch.all(self.misalignment == 0):
return R
else:
R_exit, R_entry = misalignment_matrix(self.misalignment)
R = torch.matmul(R_exit, torch.matmul(R, R_entry))
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
k=self.k.repeat(shape),
misalignment=self.misalignment.repeat(shape),
misalignment=self.misalignment.repeat((*shape, 1)),
name=self.name,
)

Expand Down
37 changes: 20 additions & 17 deletions cheetah/track_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
cs = torch.cos(angle)
sn = torch.sin(angle)

tm = torch.eye(7, dtype=angle.dtype, device=angle.device)
tm[0, 0] = cs
tm[0, 2] = sn
tm[1, 1] = cs
tm[1, 3] = sn
tm[2, 0] = -sn
tm[2, 2] = cs
tm[3, 1] = -sn
tm[3, 3] = cs
tm = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1)
tm[..., 0, 0] = cs
tm[..., 0, 2] = sn
tm[..., 1, 1] = cs
tm[..., 1, 3] = sn
tm[..., 2, 0] = -sn
tm[..., 2, 2] = cs
tm[..., 3, 1] = -sn
tm[..., 3, 3] = cs

return tm

Expand Down Expand Up @@ -98,7 +98,9 @@ def base_rmatrix(

# Rotate the R matrix for skew / vertical magnets
if torch.any(tilt != 0):
R = torch.matmul(torch.matmul(rotation_matrix(-tilt), R), rotation_matrix(tilt))
R = torch.einsum(
"...ij,...jk,...kl->...il", rotation_matrix(-tilt), R, rotation_matrix(tilt)
)
return R


Expand All @@ -108,13 +110,14 @@ def misalignment_matrix(
"""Shift the beam for tracking beam through misaligned elements"""
device = misalignment.device
dtype = misalignment.dtype
batch_shape = misalignment.shape[:-1]

R_exit = torch.eye(7, device=device, dtype=dtype)
R_exit[0, 6] = misalignment[0]
R_exit[2, 6] = misalignment[1]
R_exit = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1)
R_exit[..., 0, 6] = misalignment[..., 0]
R_exit[..., 2, 6] = misalignment[..., 1]

R_entry = torch.eye(7, device=device, dtype=dtype)
R_entry[0, 6] = -misalignment[0]
R_entry[2, 6] = -misalignment[1]
R_entry = torch.eye(7, device=device, dtype=dtype).repeat(*batch_shape, 1, 1)
R_entry[..., 0, 6] = -misalignment[..., 0]
R_entry[..., 2, 6] = -misalignment[..., 1]

return R_exit, R_entry # TODO: This order is confusing, should be entry, exit
return R_entry, R_exit
46 changes: 46 additions & 0 deletions tests/test_compare_ocelot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,52 @@ def test_dipole_with_fringe_field():
)


def test_dipole_with_fringe_field_and_tilt():
"""
Test that the tracking results through a Cheeath `Dipole` element match those
through an Oclet `Bend` element when there are fringe fields and tilt, and the
e1 and e2 angles are set.
"""
# Cheetah
bend_angle = np.pi / 6
tilt_angle = np.pi / 4
incoming_beam = cheetah.ParticleBeam.from_astra(
"tests/resources/ACHIP_EA1_2021.1351.001"
)
cheetah_dipole = cheetah.Dipole(
length=torch.tensor([1.0]),
angle=torch.tensor([bend_angle]),
fringe_integral=torch.tensor([0.1]),
gap=torch.tensor([0.2]),
tilt=torch.tensor([tilt_angle]),
e1=torch.tensor([bend_angle / 2]),
e2=torch.tensor([bend_angle / 2]),
)
outgoing_beam = cheetah_dipole(incoming_beam)

# Ocelot
incoming_p_array = ocelot.astraBeam2particleArray(
"tests/resources/ACHIP_EA1_2021.1351.001"
)
ocelot_bend = ocelot.Bend(
l=1.0,
angle=bend_angle,
fint=0.1,
gap=0.2,
tilt=tilt_angle,
e1=bend_angle / 2,
e2=bend_angle / 2,
)
lattice = ocelot.MagneticLattice([ocelot_bend, ocelot.Marker()])
navigator = ocelot.Navigator(lattice)
_, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator)

assert np.allclose(
outgoing_beam.particles[0, :, :6].cpu().numpy(),
outgoing_p_array.rparticles.transpose(),
)


def test_aperture():
"""
Test that the tracking results through a Cheeath `Aperture` element match those
Expand Down
56 changes: 56 additions & 0 deletions tests/test_quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,59 @@ def test_quadrupole_off():

assert torch.allclose(outbeam_quad.sigma_x, outbeam_drift.sigma_x)
assert not torch.allclose(outbeam_quad_on.sigma_x, outbeam_drift.sigma_x)


def test_quadrupole_with_misalignments_batched():
"""
Test that a quadrupole with misalignments behaves as expected.
"""

quad_with_misalignment = Quadrupole(
length=torch.tensor([1.0]),
k1=torch.tensor([1.0]),
misalignment=torch.tensor([[0.1, 0.1]]),
)

quad_without_misalignment = Quadrupole(
length=torch.tensor([1.0]), k1=torch.tensor([1.0])
)
incoming_beam = ParameterBeam.from_parameters(
sigma_xp=torch.tensor([2e-7]), sigma_yp=torch.tensor([2e-7])
)
outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam)
outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam)

assert not torch.allclose(
outbeam_quad_with_misalignment.mu_x,
outbeam_quad_without_misalignment.mu_x,
)


def test_quadrupole_with_misalignments_multiple_batch_dimension():
"""
Test that a quadrupole with misalignments with multiple batch dimension.
"""
batch_shape = torch.Size([4, 3])
quad_with_misalignment = Quadrupole(
length=torch.tensor([1.0]),
k1=torch.tensor([1.0]),
misalignment=torch.tensor([[0.1, 0.1]]),
).broadcast(batch_shape)

quad_without_misalignment = Quadrupole(
length=torch.tensor([1.0]), k1=torch.tensor([1.0])
).broadcast(batch_shape)
incoming_beam = ParameterBeam.from_parameters(
sigma_xp=torch.tensor([2e-7]), sigma_yp=torch.tensor([2e-7])
).broadcast(batch_shape)
outbeam_quad_with_misalignment = quad_with_misalignment(incoming_beam)
outbeam_quad_without_misalignment = quad_without_misalignment(incoming_beam)

# Check that the misalignment has an effect
assert not torch.allclose(
outbeam_quad_with_misalignment.mu_x,
outbeam_quad_without_misalignment.mu_x,
)

# Check that the output shape is correct
assert outbeam_quad_with_misalignment.mu_x.shape == batch_shape
Loading