Skip to content

Commit de10473

Browse files
authored
Merge branch 'master' into fix-ares-rl-speed
2 parents 05a032c + 8c7de4c commit de10473

8 files changed

+64
-60
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
- The covariance properties of the different beam classes were renamed from names like `cor_x` and `sigma_xpx` to consistent names like `cov_xpx` (see #331) (@jank324)
1010
- The signature of the `transfer_map` method of all element subclasses was extended by a non-optional `species` argument (see #276) (@cr-xu, @jank324, @Hespe)
1111
- `ParticleBeam.plot_distribution` allows for Seaborn-style passing of `axs` and returns the latter as well. In line with that change for the purpose of overlaying distributions, the `contour` argument of `ParticleBeam.plot_2d_distribution` was replaced by a `style` argument. (see #330) (@jank324)
12+
- The default values for `total_charge` in both beam classes are no longer `0.0` but more sensible values (see #377) (@jank324)
13+
- `ParameterBeam._mu` and `ParameterBeam._cov` were renamed to `ParameterBeam.mu` and `ParameterBeam.cov` (see #378) (@jank324)
1214

1315
### 🚀 Features
1416

cheetah/accelerator/cavity.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def _track_beam(self, incoming: Beam) -> Beam:
111111

112112
tm = self.transfer_map(incoming.energy, incoming.species)
113113
if isinstance(incoming, ParameterBeam):
114-
outgoing_mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
114+
outgoing_mu = torch.matmul(tm, incoming.mu.unsqueeze(-1)).squeeze(-1)
115115
outgoing_cov = torch.matmul(
116-
tm, torch.matmul(incoming._cov, tm.transpose(-2, -1))
116+
tm, torch.matmul(incoming.cov, tm.transpose(-2, -1))
117117
)
118118
else: # ParticleBeam
119119
outgoing_particles = torch.matmul(incoming.particles, tm.transpose(-2, -1))
@@ -133,12 +133,12 @@ def _track_beam(self, incoming: Beam) -> Beam:
133133
)
134134

135135
if isinstance(incoming, ParameterBeam):
136-
outgoing_mu[..., 5] = incoming._mu[..., 5] * incoming.energy * beta0 / (
136+
outgoing_mu[..., 5] = incoming.mu[..., 5] * incoming.energy * beta0 / (
137137
outgoing_energy * beta1
138138
) + self.voltage * beta0 / (outgoing_energy * beta1) * (
139-
torch.cos(-incoming._mu[..., 4] * beta0 * k + phi) - torch.cos(phi)
139+
torch.cos(-incoming.mu[..., 4] * beta0 * k + phi) - torch.cos(phi)
140140
)
141-
outgoing_cov[..., 5, 5] = incoming._cov[..., 5, 5]
141+
outgoing_cov[..., 5, 5] = incoming.cov[..., 5, 5]
142142
else: # ParticleBeam
143143
outgoing_particles[..., 5] = incoming.particles[
144144
..., 5
@@ -202,19 +202,19 @@ def _track_beam(self, incoming: Beam) -> Beam:
202202

203203
if isinstance(incoming, ParameterBeam):
204204
outgoing_mu[..., 4] = outgoing_mu[..., 4] + (
205-
T566 * incoming._mu[..., 5] ** 2
206-
+ T556 * incoming._mu[..., 4] * incoming._mu[..., 5]
207-
+ T555 * incoming._mu[..., 4] ** 2
205+
T566 * incoming.mu[..., 5] ** 2
206+
+ T556 * incoming.mu[..., 4] * incoming.mu[..., 5]
207+
+ T555 * incoming.mu[..., 4] ** 2
208208
)
209209
outgoing_cov[..., 4, 4] = (
210-
T566 * incoming._cov[..., 5, 5] ** 2
211-
+ T556 * incoming._cov[..., 4, 5] * incoming._cov[..., 5, 5]
212-
+ T555 * incoming._cov[..., 4, 4] ** 2
210+
T566 * incoming.cov[..., 5, 5] ** 2
211+
+ T556 * incoming.cov[..., 4, 5] * incoming.cov[..., 5, 5]
212+
+ T555 * incoming.cov[..., 4, 4] ** 2
213213
)
214214
outgoing_cov[..., 4, 5] = (
215-
T566 * incoming._cov[..., 5, 5] ** 2
216-
+ T556 * incoming._cov[..., 4, 5] * incoming._cov[..., 5, 5]
217-
+ T555 * incoming._cov[..., 4, 4] ** 2
215+
T566 * incoming.cov[..., 5, 5] ** 2
216+
+ T556 * incoming.cov[..., 4, 5] * incoming.cov[..., 5, 5]
217+
+ T555 * incoming.cov[..., 4, 4] ** 2
218218
)
219219
outgoing_cov[..., 5, 4] = outgoing_cov[..., 4, 5]
220220
else: # ParticleBeam

cheetah/accelerator/element.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def track(self, incoming: Beam) -> Beam:
6868
"""
6969
if isinstance(incoming, ParameterBeam):
7070
tm = self.transfer_map(incoming.energy, incoming.species)
71-
mu = torch.matmul(tm, incoming._mu.unsqueeze(-1)).squeeze(-1)
72-
cov = torch.matmul(tm, torch.matmul(incoming._cov, tm.transpose(-2, -1)))
71+
mu = torch.matmul(tm, incoming.mu.unsqueeze(-1)).squeeze(-1)
72+
cov = torch.matmul(tm, torch.matmul(incoming.cov, tm.transpose(-2, -1)))
7373
return ParameterBeam(
7474
mu,
7575
cov,

cheetah/accelerator/screen.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,12 @@ def track(self, incoming: Beam) -> Beam:
166166

167167
if isinstance(incoming, ParameterBeam):
168168
broadcasted_mu, _ = torch.broadcast_tensors(
169-
copy_of_incoming._mu, self.misalignment[..., 0]
169+
copy_of_incoming.mu, self.misalignment[..., 0]
170170
)
171-
copy_of_incoming._mu = broadcasted_mu.clone()
171+
copy_of_incoming.mu = broadcasted_mu.clone()
172172

173-
copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0]
174-
copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1]
173+
copy_of_incoming.mu[..., 0] -= self.misalignment[..., 0]
174+
copy_of_incoming.mu[..., 2] -= self.misalignment[..., 1]
175175
elif isinstance(incoming, ParticleBeam):
176176
broadcasted_particles, _ = torch.broadcast_tensors(
177177
copy_of_incoming.particles,
@@ -192,8 +192,8 @@ def track(self, incoming: Beam) -> Beam:
192192
if self.is_active and self.is_blocking:
193193
if isinstance(incoming, ParameterBeam):
194194
return ParameterBeam(
195-
mu=incoming._mu,
196-
cov=incoming._cov,
195+
mu=incoming.mu,
196+
cov=incoming.cov,
197197
energy=incoming.energy,
198198
total_charge=torch.zeros_like(incoming.total_charge),
199199
)
@@ -222,23 +222,23 @@ def reading(self) -> torch.Tensor:
222222
dtype=self.misalignment.dtype,
223223
)
224224
elif isinstance(read_beam, ParameterBeam):
225-
if torch.numel(read_beam._mu[..., 0]) > 1:
225+
if torch.numel(read_beam.mu[..., 0]) > 1:
226226
raise NotImplementedError(
227227
"`Screen` does not support vectorization of `ParameterBeam`. "
228228
"Please use `ParticleBeam` instead. If this is a feature you would "
229229
"like to see, please open an issue on GitHub."
230230
)
231231

232232
transverse_mu = torch.stack(
233-
[read_beam._mu[..., 0], read_beam._mu[..., 2]], dim=-1
233+
[read_beam.mu[..., 0], read_beam.mu[..., 2]], dim=-1
234234
)
235235
transverse_cov = torch.stack(
236236
[
237237
torch.stack(
238-
[read_beam._cov[..., 0, 0], read_beam._cov[..., 0, 2]], dim=-1
238+
[read_beam.cov[..., 0, 0], read_beam.cov[..., 0, 2]], dim=-1
239239
),
240240
torch.stack(
241-
[read_beam._cov[..., 2, 0], read_beam._cov[..., 2, 2]], dim=-1
241+
[read_beam.cov[..., 2, 0], read_beam.cov[..., 2, 2]], dim=-1
242242
),
243243
],
244244
dim=-1,

cheetah/particles/parameter_beam.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ def __init__(
3737
factory_kwargs = {"device": device, "dtype": dtype}
3838
super().__init__()
3939

40-
self.register_buffer_or_parameter("_mu", torch.as_tensor(mu, **factory_kwargs))
40+
self.species = species if species is not None else Species("electron")
41+
42+
self.register_buffer_or_parameter("mu", torch.as_tensor(mu, **factory_kwargs))
4143
self.register_buffer_or_parameter(
42-
"_cov", torch.as_tensor(cov, **factory_kwargs)
44+
"cov", torch.as_tensor(cov, **factory_kwargs)
4345
)
4446
self.register_buffer_or_parameter(
4547
"energy", torch.as_tensor(energy, **factory_kwargs)
@@ -51,8 +53,6 @@ def __init__(
5153
),
5254
)
5355

54-
self.species = species if species is not None else Species("electron")
55-
5656
@classmethod
5757
def from_parameters(
5858
cls,
@@ -471,8 +471,8 @@ def as_particle_beam(self, num_particles: int) -> "ParticleBeam": # noqa: F821
471471
cov_taup=self.cov_taup,
472472
energy=self.energy,
473473
total_charge=self.total_charge,
474-
device=self._mu.device,
475-
dtype=self._mu.dtype,
474+
device=self.mu.device,
475+
dtype=self.mu.dtype,
476476
)
477477

478478
def linspaced(self, num_particles: int) -> "ParticleBeam": # noqa: F821
@@ -502,82 +502,82 @@ def linspaced(self, num_particles: int) -> "ParticleBeam": # noqa: F821
502502
energy=self.energy,
503503
total_charge=self.total_charge,
504504
species=self.species,
505-
device=self._mu.device,
506-
dtype=self._mu.dtype,
505+
device=self.mu.device,
506+
dtype=self.mu.dtype,
507507
)
508508

509509
@property
510510
def mu_x(self) -> torch.Tensor:
511-
return self._mu[..., 0]
511+
return self.mu[..., 0]
512512

513513
@property
514514
def sigma_x(self) -> torch.Tensor:
515-
return torch.sqrt(torch.clamp_min(self._cov[..., 0, 0], 1e-20))
515+
return torch.sqrt(torch.clamp_min(self.cov[..., 0, 0], 1e-20))
516516

517517
@property
518518
def mu_px(self) -> torch.Tensor:
519-
return self._mu[..., 1]
519+
return self.mu[..., 1]
520520

521521
@property
522522
def sigma_px(self) -> torch.Tensor:
523-
return torch.sqrt(torch.clamp_min(self._cov[..., 1, 1], 1e-20))
523+
return torch.sqrt(torch.clamp_min(self.cov[..., 1, 1], 1e-20))
524524

525525
@property
526526
def mu_y(self) -> torch.Tensor:
527-
return self._mu[..., 2]
527+
return self.mu[..., 2]
528528

529529
@property
530530
def sigma_y(self) -> torch.Tensor:
531-
return torch.sqrt(torch.clamp_min(self._cov[..., 2, 2], 1e-20))
531+
return torch.sqrt(torch.clamp_min(self.cov[..., 2, 2], 1e-20))
532532

533533
@property
534534
def mu_py(self) -> torch.Tensor:
535-
return self._mu[..., 3]
535+
return self.mu[..., 3]
536536

537537
@property
538538
def sigma_py(self) -> torch.Tensor:
539-
return torch.sqrt(torch.clamp_min(self._cov[..., 3, 3], 1e-20))
539+
return torch.sqrt(torch.clamp_min(self.cov[..., 3, 3], 1e-20))
540540

541541
@property
542542
def mu_tau(self) -> torch.Tensor:
543-
return self._mu[..., 4]
543+
return self.mu[..., 4]
544544

545545
@property
546546
def sigma_tau(self) -> torch.Tensor:
547-
return torch.sqrt(torch.clamp_min(self._cov[..., 4, 4], 1e-20))
547+
return torch.sqrt(torch.clamp_min(self.cov[..., 4, 4], 1e-20))
548548

549549
@property
550550
def mu_p(self) -> torch.Tensor:
551-
return self._mu[..., 5]
551+
return self.mu[..., 5]
552552

553553
@property
554554
def sigma_p(self) -> torch.Tensor:
555-
return torch.sqrt(torch.clamp_min(self._cov[..., 5, 5], 1e-20))
555+
return torch.sqrt(torch.clamp_min(self.cov[..., 5, 5], 1e-20))
556556

557557
@property
558558
def cov_xpx(self) -> torch.Tensor:
559-
return self._cov[..., 0, 1]
559+
return self.cov[..., 0, 1]
560560

561561
@property
562562
def cov_ypy(self) -> torch.Tensor:
563-
return self._cov[..., 2, 3]
563+
return self.cov[..., 2, 3]
564564

565565
@property
566566
def cov_taup(self) -> torch.Tensor:
567-
return self._cov[..., 4, 5]
567+
return self.cov[..., 4, 5]
568568

569569
def clone(self) -> "ParameterBeam":
570570
return self.__class__(
571-
mu=self._mu.clone(),
572-
cov=self._cov.clone(),
571+
mu=self.mu.clone(),
572+
cov=self.cov.clone(),
573573
energy=self.energy.clone(),
574574
total_charge=self.total_charge.clone(),
575575
)
576576

577577
def __repr__(self) -> str:
578578
return (
579-
f"{self.__class__.__name__}(mu={repr(self._mu)}, "
580-
+ f"cov={repr(self._cov)}, "
579+
f"{self.__class__.__name__}(mu={repr(self.mu)}, "
580+
+ f"cov={repr(self.cov)}, "
581581
+ f"energy={repr(self.energy)}, "
582582
+ f"total_charge={repr(self.total_charge)}, "
583583
+ f"species={repr(self.species)})"

cheetah/particles/particle_beam.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ def make_linspaced(
570570
)
571571
factory_kwargs = {"device": device, "dtype": dtype}
572572

573+
species = species if species is not None else Species("electron")
574+
573575
# Set default values without function call in function signature
574576
mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs)
575577
mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs)
@@ -599,7 +601,7 @@ def make_linspaced(
599601
total_charge = (
600602
total_charge
601603
if total_charge is not None
602-
else torch.tensor(0.0, **factory_kwargs)
604+
else species.charge_coulomb * num_particles
603605
)
604606
particle_charges = (
605607
torch.ones((*total_charge.shape, num_particles), **factory_kwargs)

tests/test_differentiable.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ def test_ea_incoming_parameter_beam():
6767
"tests/resources/ACHIP_EA1_2021.1351.001"
6868
)
6969

70-
incoming_beam._mu = nn.Parameter(incoming_beam._mu)
71-
incoming_beam._cov = nn.Parameter(incoming_beam._cov)
70+
incoming_beam.mu = nn.Parameter(incoming_beam.mu)
71+
incoming_beam.cov = nn.Parameter(incoming_beam.cov)
7272

7373
outgoing_beam = ea.track(incoming_beam)
7474

75-
assert outgoing_beam._mu.grad_fn is not None
76-
assert outgoing_beam._cov.grad_fn is not None
75+
assert outgoing_beam.mu.grad_fn is not None
76+
assert outgoing_beam.cov.grad_fn is not None
7777

7878

7979
def test_ea_incoming_particle_beam():

tests/test_parameter_beam.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def test_from_twiss_dtype():
121121
assert np.isclose(beam.emittance_y.cpu().numpy(), 3.497810737006068e-09)
122122
assert np.isclose(beam.energy.cpu().numpy(), 6e6)
123123

124-
assert beam._mu.dtype == torch.float64
125-
assert beam._cov.dtype == torch.float64
124+
assert beam.mu.dtype == torch.float64
125+
assert beam.cov.dtype == torch.float64
126126

127127

128128
def test_conversion_to_and_from_particle_beam():

0 commit comments

Comments
 (0)