Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix twiss plot #213

Merged
merged 19 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This is a major release with significant upgrades under the hood of Cheetah. Des

### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting in a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of your element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198, #208, #213, #215, #218, #229, #233) (@jank324, @cr-xu, @hespe, @roussel-ryan)
- The fifth particle coordinate `s` is renamed to `tau`. Now Cheetah uses the canonical variables in phase space $(x,px=\frac{P_x}{p_0},y,py, \tau=c\Delta t, \delta=\Delta E/{p_0 c})$. In addition, the trailing "s" was removed from some beam property names (e.g. `beam.xs` becomes `beam.x`). (see #163) (@cr-xu)
- `Screen` no longer blocks the beam (by default). To return to old behaviour, set `Screen.is_blocking = True`. (see #208) (@jank324, @roussel-ryan)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="gold", alpha=alpha, zorder=2
(s, 0), self.length, height, color="gold", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,5 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle((s, 0), self.length[0], height, color="tab:olive", zorder=2)
patch = Rectangle((s, 0), self.length, height, color="tab:olive", zorder=2)
ax.add_patch(patch)
4 changes: 2 additions & 2 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,9 @@ def defining_features(self) -> list[str]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:green", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:green", alpha=alpha, zorder=2
)
ax.add_patch(patch)
4 changes: 2 additions & 2 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:blue", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.k1[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.k1) if self.is_active else 1)
patch = Rectangle(
(s, 0), self.length[0], height, color="tab:red", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:red", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
8 changes: 4 additions & 4 deletions cheetah/accelerator/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
]

def plot(self, ax: plt.Axes, s: float) -> None:
element_lengths = [element.length[0] for element in self.elements]
element_lengths = [element.length for element in self.elements]
element_ss = [0] + [
sum(element_lengths[: i + 1]) for i, _ in enumerate(element_lengths)
]
Expand Down Expand Up @@ -423,7 +423,7 @@ def plot_reference_particle_traces(
reference_segment = deepcopy(self)
splits = reference_segment.split(resolution=torch.tensor(resolution))

split_lengths = [split.length[0] for split in splits]
split_lengths = [split.length for split in splits]
ss = [0] + [sum(split_lengths[: i + 1]) for i, _ in enumerate(split_lengths)]

references = []
Expand Down Expand Up @@ -464,7 +464,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
xs = [
float(reference_beam.x[0, particle_index].cpu())
reference_beam.x[particle_index]
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand All @@ -475,7 +475,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
ys = [
float(reference_beam.ys[0, particle_index].cpu())
reference_beam.y[particle_index]
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.8

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:orange", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:orange", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="olive", alpha=alpha, zorder=2
(s, 0), self.length, height, color="olive", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/undulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def plot(self, ax: plt.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:purple", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:purple", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
4 changes: 2 additions & 2 deletions cheetah/accelerator/vertical_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:cyan", alpha=alpha, zorder=2
(s, 0), self.length, height, color="tab:cyan", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down
138 changes: 38 additions & 100 deletions cheetah/particles/particle_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.constants import physical_constants
from torch.distributions import MultivariateNormal

from ..utils import elementwise_linspace
from .beam import Beam

speed_of_light = torch.tensor(constants.speed_of_light) # In m/s
Expand Down Expand Up @@ -450,118 +451,55 @@ def make_linspaced(
:param device: Device to move the beam's particle array to. If set to `"auto"` a
CUDA GPU is selected if available. The CPU is used otherwise.
"""
# Figure out if arguments were passed, figure out their shape
not_nones = [
argument
for argument in [
mu_x,
mu_px,
mu_y,
mu_py,
sigma_x,
sigma_px,
sigma_y,
sigma_py,
sigma_tau,
sigma_p,
energy,
total_charge,
]
if argument is not None
]
shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1])
if len(not_nones) > 1:
assert all(
argument.shape == shape for argument in not_nones
), "Arguments must have the same shape."

# Set default values without function call in function signature
num_particles = num_particles if num_particles is not None else torch.tensor(10)
mu_x = mu_x if mu_x is not None else torch.full(shape, 0.0)
mu_px = mu_px if mu_px is not None else torch.full(shape, 0.0)
mu_y = mu_y if mu_y is not None else torch.full(shape, 0.0)
mu_py = mu_py if mu_py is not None else torch.full(shape, 0.0)
sigma_x = sigma_x if sigma_x is not None else torch.full(shape, 175e-9)
sigma_px = sigma_px if sigma_px is not None else torch.full(shape, 2e-7)
sigma_y = sigma_y if sigma_y is not None else torch.full(shape, 175e-9)
sigma_py = sigma_py if sigma_py is not None else torch.full(shape, 2e-7)
sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 0.0)
sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 0.0)
energy = energy if energy is not None else torch.full(shape, 1e8)
total_charge = (
total_charge if total_charge is not None else torch.full(shape, 0.0)
)

mu_x = mu_x if mu_x is not None else torch.tensor(0.0)
mu_px = mu_px if mu_px is not None else torch.tensor(0.0)
mu_y = mu_y if mu_y is not None else torch.tensor(0.0)
mu_py = mu_py if mu_py is not None else torch.tensor(0.0)
sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9)
sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7)
sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9)
sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7)
sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6)
sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6)
energy = energy if energy is not None else torch.tensor(1e8)
total_charge = total_charge if total_charge is not None else torch.tensor(0.0)
particle_charges = (
torch.ones((shape[0], num_particles), device=device, dtype=dtype)
* total_charge.view(-1, 1)
torch.ones((*total_charge.shape, num_particles))
* total_charge.unsqueeze(-1)
/ num_particles
)

particles = torch.ones((shape[0], num_particles, 7))

particles[:, :, 0] = torch.stack(
[
torch.linspace(
sample_mu_x - sample_sigma_x,
sample_mu_x + sample_sigma_x,
num_particles,
)
for sample_mu_x, sample_sigma_x in zip(mu_x, sigma_x)
],
dim=0,
vector_shape = torch.broadcast_shapes(
mu_x.shape,
mu_px.shape,
mu_y.shape,
mu_py.shape,
sigma_x.shape,
sigma_px.shape,
sigma_y.shape,
sigma_py.shape,
sigma_tau.shape,
sigma_p.shape,
)
particles[:, :, 1] = torch.stack(
[
torch.linspace(
sample_mu_px - sample_sigma_px,
sample_mu_px + sample_sigma_px,
num_particles,
)
for sample_mu_px, sample_sigma_px in zip(mu_px, sigma_px)
],
dim=0,
)
particles[:, :, 2] = torch.stack(
[
torch.linspace(
sample_mu_y - sample_sigma_y,
sample_mu_y + sample_sigma_y,
num_particles,
)
for sample_mu_y, sample_sigma_y in zip(mu_y, sigma_y)
],
dim=0,
particles = torch.ones((*vector_shape, num_particles, 7))

particles[..., 0] = elementwise_linspace(
mu_x - sigma_x, mu_x + sigma_x, num_particles
)
particles[:, :, 3] = torch.stack(
[
torch.linspace(
sample_mu_py - sample_sigma_py,
sample_mu_py + sample_sigma_py,
num_particles,
)
for sample_mu_py, sample_sigma_py in zip(mu_py, sigma_py)
],
dim=0,
particles[..., 1] = elementwise_linspace(
mu_px - sigma_px, mu_px + sigma_px, num_particles
)
particles[:, :, 4] = torch.stack(
[
torch.linspace(
-sample_sigma_tau, sample_sigma_tau, num_particles, device=device
)
for sample_sigma_tau in sigma_tau
],
dim=0,
particles[..., 2] = elementwise_linspace(
mu_y - sigma_y, mu_y + sigma_y, num_particles
)
particles[:, :, 5] = torch.stack(
[
torch.linspace(
-sample_sigma_p, sample_sigma_p, num_particles, device=device
)
for sample_sigma_p in sigma_p
],
dim=0,
particles[..., 3] = elementwise_linspace(
mu_py - sigma_py, mu_py + sigma_py, num_particles
)
particles[..., 4] = elementwise_linspace(-sigma_tau, sigma_tau, num_particles)
particles[..., 5] = elementwise_linspace(-sigma_p, sigma_p, num_particles)

return cls(
particles=particles,
Expand Down
1 change: 1 addition & 0 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import bmadx # noqa: F401
from .device import is_mps_available_and_functional # noqa: F401
from .elementwise_linspace import elementwise_linspace # noqa: F401
from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401
from .physics import compute_relativistic_factors # noqa: F401
from .unique_name_generator import UniqueNameGenerator # noqa: F401
35 changes: 35 additions & 0 deletions cheetah/utils/elementwise_linspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch


def elementwise_linspace(
    start: torch.Tensor, end: torch.Tensor, steps: int
) -> torch.Tensor:
    """
    Generate a tensor of linearly spaced values between two tensors element-wise.

    :param start: Any-dimensional tensor of the starting value for the set of points.
    :param end: Any-dimensional tensor of the ending value for the set of points.
        Must be broadcastable with `start`.
    :param steps: Size of the last dimension of the constructed tensor.
    :return: A tensor of shape `broadcast(start, end).shape + (steps,)` containing
        `steps` linearly spaced values between each pair of elements in `start` and
        `end`.
    """
    # Interpolation fractions 0.0 ... 1.0 along the new trailing dimension. For
    # steps == 1 this is [0.0], matching torch.linspace's convention of returning
    # the start value.
    fractions = torch.linspace(0.0, 1.0, steps)

    # Vectorised interpolation start + t * (end - start). Unlike a Python loop of
    # per-element torch.linspace calls, this runs entirely in C, broadcasts start
    # against end, and handles empty and 0-d tensors without special-casing.
    return start.unsqueeze(-1) + fractions * (end - start).unsqueeze(-1)
32 changes: 11 additions & 21 deletions docs/examples/simple.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions tests/test_elementwise_linspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch

from cheetah.utils import elementwise_linspace


def test_example():
    """Tests an example case with two 2D tensors."""
    start = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    end = torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]])
    steps = 5

    result = elementwise_linspace(start, end, steps)

    # Check shape: original dimensions plus a trailing dimension of size `steps`
    assert result.shape == (2, 3, 5)

    # Check that edges are correct
    assert torch.allclose(result[:, :, 0], start)
    assert torch.allclose(result[:, :, -1], end)

    # Check that the values are linearly interpolated for each linspace
    assert torch.allclose(result[0, 0, :], torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))
    assert torch.allclose(result[0, 1, :], torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))
    assert torch.allclose(result[0, 2, :], torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0]))
    assert torch.allclose(result[1, 0, :], torch.tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
    assert torch.allclose(result[1, 1, :], torch.tensor([5.0, 6.0, 7.0, 8.0, 9.0]))
    assert torch.allclose(result[1, 2, :], torch.tensor([6.0, 7.0, 8.0, 9.0, 10.0]))
Loading
Loading