Skip to content

Commit 8c7de4c

Browse files
jank324cr-xu
andauthored
Assume more reasonable default values for total_charge (#377)
* Assume more reasonable default values for `total_charge` * Update changelog * Fix failing test --------- Co-authored-by: Chenran Xu <xuchenran@gmail.com>
1 parent 2452153 commit 8c7de4c

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- The covariance properties of the different beam classes were renamed from names like `cor_x` and `sigma_xpx` to consistent names like `cov_xpx` (see #331) (@jank324)
1010
- The signature of the `transfer_map` method of all element subclasses was extended by a non-optional `species` argument (see #276) (@cr-xu, @jank324, @Hespe)
1111
- `ParticleBeam.plot_distribution` allows for Seaborn-style passing of `axs` and returns the latter as well. In line with that change for the purpose of overlaying distributions, the `contour` argument of `ParticleBeam.plot_2d_distribution` was replaced by a `style` argument. (see #330) (@jank324)
12+
- The default values for `total_charge` in both beam classes are no longer `0.0` but more sensible values (see #377) (@jank324)
1213
- `ParameterBeam._mu` and `ParameterBeam._cov` were renamed to `ParameterBeam.mu` and `ParameterBeam.cov` (see #378) (@jank324)
1314

1415
### 🚀 Features

cheetah/particles/parameter_beam.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,22 @@ def __init__(
3737
factory_kwargs = {"device": device, "dtype": dtype}
3838
super().__init__()
3939

40+
self.species = species if species is not None else Species("electron")
41+
4042
self.register_buffer("mu", None)
4143
self.register_buffer("cov", None)
4244
self.register_buffer("energy", None)
43-
self.register_buffer("total_charge", torch.tensor(0.0, **factory_kwargs))
45+
self.register_buffer(
46+
"total_charge",
47+
torch.as_tensor(self.species.charge_coulomb, **factory_kwargs),
48+
)
4449

4550
self.mu = torch.as_tensor(mu, **factory_kwargs)
4651
self.cov = torch.as_tensor(cov, **factory_kwargs)
4752
self.energy = torch.as_tensor(energy, **factory_kwargs)
4853
if total_charge is not None:
4954
self.total_charge = torch.as_tensor(total_charge, **factory_kwargs)
5055

51-
self.species = species if species is not None else Species("electron")
52-
5356
@classmethod
5457
def from_parameters(
5558
cls,

cheetah/particles/particle_beam.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ def make_linspaced(
564564
)
565565
factory_kwargs = {"device": device, "dtype": dtype}
566566

567+
species = species if species is not None else Species("electron")
568+
567569
# Set default values without function call in function signature
568570
mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs)
569571
mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs)
@@ -593,7 +595,7 @@ def make_linspaced(
593595
total_charge = (
594596
total_charge
595597
if total_charge is not None
596-
else torch.tensor(0.0, **factory_kwargs)
598+
else species.charge_coulomb * num_particles
597599
)
598600
particle_charges = (
599601
torch.ones((*total_charge.shape, num_particles), **factory_kwargs)

0 commit comments

Comments
 (0)