Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BmadQuadrupole element #153

Merged
merged 39 commits into from
Jul 24, 2024
Merged
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
ce88eeb
add `BmadQuadrupole` element
jp-ga May 8, 2024
20dbeb2
remove unused files
jp-ga May 8, 2024
b80bae8
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jp-ga May 8, 2024
ea30066
merge master into 139
jp-ga Jun 18, 2024
f74fb7c
implement bmadx quad
jp-ga Jun 20, 2024
257480d
cleanup
jp-ga Jun 20, 2024
5ce9da8
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jp-ga Jun 20, 2024
0b2ceca
merge hotfix
jp-ga Jun 20, 2024
a106cf0
run isort
jp-ga Jun 20, 2024
75c915d
run black
jp-ga Jun 20, 2024
4afbc7c
hotfix flake8
jp-ga Jun 20, 2024
0d702e6
hotfix flake8 again
jp-ga Jun 20, 2024
32c9adc
add missing docstrings and typing
jp-ga Jun 25, 2024
83c2f46
fix isort
jp-ga Jun 25, 2024
2da4d42
run black
jp-ga Jun 25, 2024
647d401
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jun 26, 2024
90e4e30
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 9, 2024
5c1c17b
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 20, 2024
ac8202a
Move Bmad-X utils to new utils directory
jank324 Jul 20, 2024
3838550
A little cleanup
jank324 Jul 20, 2024
f5ffdaf
Fix import error
jank324 Jul 20, 2024
94b65a9
Light refactoring
jank324 Jul 20, 2024
ddbfb54
Fix typo in PR template
jank324 Jul 20, 2024
ca5ca6b
Reduce cope duplication in `Quadrupole.track`
jank324 Jul 20, 2024
5ad2de2
Fix Bmad-X quadrupole dev notebook
jank324 Jul 20, 2024
073a9bc
Simplify `is_skippable`
jank324 Jul 20, 2024
c5079fd
Clean up test
jank324 Jul 20, 2024
0115d3a
Add a test that finds Ryan's error
jank324 Jul 20, 2024
340ed80
Fix vectorisation issue with Bmad-X quadrupole tracking
jank324 Jul 20, 2024
8505c4e
Rearrange test reources for Bmad-X quadrupole implementation
jank324 Jul 20, 2024
6a15c64
Add changelog entry
jank324 Jul 20, 2024
67b746c
add num_steps and tracking_method to split
jp-ga Jul 22, 2024
3ec7789
Apply suggestions from code review
cr-xu Jul 23, 2024
7cbcaf4
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
3c851ea
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
4f931b2
Apply suggestions from code review
cr-xu Jul 23, 2024
111f32c
Apply suggestions from code review
cr-xu Jul 23, 2024
bddb7c2
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
c1e91ed
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -40,6 +40,6 @@
- [ ] I have run `pytest` on a machine with a CUDA GPU and made sure all tests pass (**required**).
- [ ] I have checked that the documentation builds (**required**).

Note: We are using a maximum length of 88 characters per line
Note: We are using a maximum length of 88 characters per line.

<!--- This Template is an edited version of the one from https://github.com/DLR-RM/stable-baselines3/ -->
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
- `Segment`s can now be imported from Bmad to devices other than `torch.device("cpu")` and dtypes other than `torch.float32` (see #196, #206) (@jank324)
- `Screen` now offers the option to use KDE for differentiable images (see #200) (@cr-xu, @roussel-ryan)
- Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324)
- `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324)

### 🐛 Bug fixes

127 changes: 124 additions & 3 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from typing import Optional, Union
from typing import Literal, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.particles import Beam, ParticleBeam
from cheetah.track_methods import base_rmatrix, misalignment_matrix
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, bmadx

from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
)


class Quadrupole(Element):
"""
@@ -23,6 +29,9 @@ class Quadrupole(Element):
:param misalignment: Misalignment vector of the quadrupole in x- and y-directions.
:param tilt: Tilt angle of the quadrupole in x-y plane [rad]. pi/4 for
skew-quadrupole.
:param num_steps: Number of drift-kick-drift steps to use for tracking through the
element when tracking method is set to `"bmadx"`.
:param tracking_method: Method to use for tracking through the element.
:param name: Unique identifier of the element.
"""

@@ -32,6 +41,8 @@ def __init__(
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
num_steps: int = 1,
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
@@ -64,6 +75,8 @@ def __init__(
else torch.zeros_like(self.length)
),
)
self.num_steps = num_steps
self.tracking_method = tracking_method

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
R = base_rmatrix(
@@ -81,6 +94,110 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R

def track(self, incoming: Beam) -> Beam:
"""
Track particles through the quadrupole element.
:param incoming: Beam entering the element.
:return: Beam exiting the element.
"""
if self.tracking_method == "cheetah":
return super().track(incoming)
elif self.tracking_method == "bmadx":
assert isinstance(
incoming, ParticleBeam
), "Bmad-X tracking is currently only supported for `ParticleBeam`."
return self._track_bmadx(incoming)
else:
raise ValueError(
f"Invalid tracking method {self.tracking_method}. "
+ "Supported methods are 'cheetah' and 'bmadx'."
)

def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
"""
Track particles through the quadrupole element using the Bmad-X tracking method.
:param incoming: Beam entering the element. Currently only supports
`ParticleBeam`.
:return: Beam exiting the element.
"""
# Compute Bmad coordinates and p0c
mc2 = electron_mass_eV.to(
device=incoming.particles.device, dtype=incoming.particles.dtype
)
bmad_coords, p0c = bmadx.cheetah_to_bmad_coords(
incoming.particles, incoming.energy, mc2
)
x = bmad_coords[..., 0]
px = bmad_coords[..., 1]
y = bmad_coords[..., 2]
py = bmad_coords[..., 3]
z = bmad_coords[..., 4]
pz = bmad_coords[..., 5]

x_offset = self.misalignment[..., 0]
y_offset = self.misalignment[..., 1]

step_length = self.length / self.num_steps
b1 = self.k1 * self.length

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
x_offset, y_offset, self.tilt, x, px, y, py
)

for _ in range(self.num_steps):
rel_p = 1 + pz # Particle's relative momentum (P/P0)
k1 = b1.unsqueeze(-1) / (self.length.unsqueeze(-1) * rel_p)

tx, dzx = bmadx.calculate_quadrupole_coefficients(-k1, step_length, rel_p)
ty, dzy = bmadx.calculate_quadrupole_coefficients(k1, step_length, rel_p)

z = (
z
+ dzx[0] * x**2
+ dzx[1] * x * px
+ dzx[2] * px**2
+ dzy[0] * y**2
+ dzy[1] * y * py
+ dzy[2] * py**2
)

x_next = tx[0][0] * x + tx[0][1] * px
px_next = tx[1][0] * x + tx[1][1] * px
y_next = ty[0][0] * y + ty[0][1] * py
py_next = ty[1][0] * y + ty[1][1] * py

x, px, y, py = x_next, px_next, y_next, py_next

z = z + bmadx.low_energy_z_correction(pz, p0c, mc2, step_length)

# s = s + l
x, px, y, py = bmadx.offset_particle_unset(
x_offset, y_offset, self.tilt, x, px, y, py
)

# End of Bmad-X tracking
bmad_coords[..., 0] = x
bmad_coords[..., 1] = px
bmad_coords[..., 2] = y
bmad_coords[..., 3] = py
bmad_coords[..., 4] = z
bmad_coords[..., 5] = pz

# Convert back to Cheetah coordinates
cheetah_coords, ref_energy = bmadx.bmad_to_cheetah_coords(bmad_coords, p0c, mc2)

outgoing_beam = ParticleBeam(
cheetah_coords,
ref_energy,
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
return outgoing_beam

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
@@ -94,7 +211,7 @@ def broadcast(self, shape: Size) -> Element:

@property
def is_skippable(self) -> bool:
return True
return self.tracking_method == "cheetah"

@property
def is_active(self) -> bool:
@@ -109,6 +226,8 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
self.k1,
misalignment=self.misalignment,
tilt=self.tilt,
num_steps=self.num_steps,
tracking_method=self.tracking_method,
dtype=self.length.dtype,
device=self.length.device,
)
@@ -134,5 +253,7 @@ def __repr__(self) -> str:
+ f"k1={repr(self.k1)}, "
+ f"misalignment={repr(self.misalignment)}, "
+ f"tilt={repr(self.tilt)}, "
+ f"num_steps={repr(self.num_steps)}, "
+ f"tracking_method={repr(self.tracking_method)}, "
+ f"name={repr(self.name)})"
)
1 change: 1 addition & 0 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import bmadx # noqa: F401
from .device import is_mps_available_and_functional # noqa: F401
from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401
from .unique_name_generator import UniqueNameGenerator # noqa: F401
216 changes: 216 additions & 0 deletions cheetah/utils/bmadx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import torch

double_precision_epsilon = torch.finfo(torch.float64).eps


def cheetah_to_bmad_coords(
cheetah_coords: torch.Tensor, ref_energy: torch.Tensor, mc2: torch.Tensor
) -> torch.Tensor:
"""
Transforms Cheetah coordinates to Bmad coordinates.
:param cheetah_coords: 7-dimensional particle vectors in Cheetah coordinates.
:param ref_energy: Reference energy in eV.
"""
# TODO This can probably be moved to the `ParticleBeam` class at some point

# Initialize Bmad coordinates
bmad_coords = cheetah_coords[..., :6].clone()

# Cheetah longitudinal coordinates
tau = cheetah_coords[..., 4]
delta = cheetah_coords[..., 5]

# Compute p0c and Bmad z, pz
p0c = torch.sqrt(ref_energy**2 - mc2**2)
energy = ref_energy.unsqueeze(-1) + delta * p0c.unsqueeze(-1)
p = torch.sqrt(energy**2 - mc2**2)
beta = p / energy
z = -beta * tau
pz = (p - p0c.unsqueeze(-1)) / p0c.unsqueeze(-1)

# Bmad coordinates
bmad_coords[..., 4] = z
bmad_coords[..., 5] = pz

return bmad_coords, p0c


def bmad_to_cheetah_coords(
bmad_coords: torch.Tensor, p0c: torch.Tensor, mc2: torch.Tensor
) -> torch.Tensor:
"""
Transforms Bmad coordinates to Cheetah coordinates.
:param bmad_coords: 6-dimensional particle vectors in Bmad coordinates.
:param p0c: Reference momentum in eV/c.
"""
# TODO This can probably be moved to the `ParticleBeam` class at some point

# Initialize Cheetah coordinates
cheetah_coords = torch.ones(
(*bmad_coords.shape[:-1], 7), dtype=bmad_coords.dtype, device=bmad_coords.device
)
cheetah_coords[..., :6] = bmad_coords.clone()

# Bmad longitudinal coordinates
z = bmad_coords[..., 4]
pz = bmad_coords[..., 5]

# Compute ref_energy and Cheetah tau, delta
ref_energy = torch.sqrt(p0c**2 + mc2**2)
p = (1 + pz) * p0c.unsqueeze(-1)
energy = torch.sqrt(p**2 + mc2**2)
beta = p / energy
tau = -z / beta
delta = (energy - ref_energy.unsqueeze(-1)) / p0c.unsqueeze(-1)

# Cheetah coordinates
cheetah_coords[..., 4] = tau
cheetah_coords[..., 5] = delta

return cheetah_coords, ref_energy


def offset_particle_set(
x_offset: torch.Tensor,
y_offset: torch.Tensor,
tilt: torch.Tensor,
x_lab: torch.Tensor,
px_lab: torch.Tensor,
y_lab: torch.Tensor,
py_lab: torch.Tensor,
) -> list[torch.Tensor]:
"""
Transforms particle coordinates from lab to element frame.
:param x_offset: Element x-coordinate offset.
:param y_offset: Element y-coordinate offset.
:param tilt: Tilt angle (rad).
:param x_lab: x-coordinate in lab frame.
:param px_lab: x-momentum in lab frame.
:param y_lab: y-coordinate in lab frame.
:param py_lab: y-momentum in lab frame.
:return: x, px, y, py coordinates in element frame.
"""
s = torch.sin(tilt)
c = torch.cos(tilt)
x_ele_int = x_lab - x_offset.unsqueeze(-1)
y_ele_int = y_lab - y_offset.unsqueeze(-1)
x_ele = x_ele_int * c.unsqueeze(-1) + y_ele_int * s.unsqueeze(-1)
y_ele = -x_ele_int * s.unsqueeze(-1) + y_ele_int * c.unsqueeze(-1)
px_ele = px_lab * c.unsqueeze(-1) + py_lab * s.unsqueeze(-1)
py_ele = -px_lab * s.unsqueeze(-1) + py_lab * c.unsqueeze(-1)

return x_ele, px_ele, y_ele, py_ele


def offset_particle_unset(
x_offset: torch.Tensor,
y_offset: torch.Tensor,
tilt: torch.Tensor,
x_ele: torch.Tensor,
px_ele: torch.Tensor,
y_ele: torch.Tensor,
py_ele: torch.Tensor,
) -> list[torch.Tensor]:
"""
Transforms particle coordinates from element to lab frame.
:param x_offset: Element x-coordinate offset.
:param y_offset: Element y-coordinate offset.
:param tilt: Tilt angle (rad).
:param x_ele: x-coordinate in element frame.
:param px_ele: x-momentum in element frame.
:param y_ele: y-coordinate in element frame.
:param py_ele: y-momentum in element frame.
:return: x, px, y, py coordinates in lab frame.
"""
s = torch.sin(tilt)
c = torch.cos(tilt)
x_lab_int = x_ele * c.unsqueeze(-1) - y_ele * s.unsqueeze(-1)
y_lab_int = x_ele * s.unsqueeze(-1) + y_ele * c.unsqueeze(-1)
x_lab = x_lab_int + x_offset.unsqueeze(-1)
y_lab = y_lab_int + y_offset.unsqueeze(-1)
px_lab = px_ele * c.unsqueeze(-1) - py_ele * s.unsqueeze(-1)
py_lab = px_ele * s.unsqueeze(-1) + py_ele * c.unsqueeze(-1)

return x_lab, px_lab, y_lab, py_lab


def low_energy_z_correction(
pz: torch.Tensor, p0c: torch.Tensor, mc2: torch.Tensor, ds: torch.Tensor
) -> torch.Tensor:
"""
Corrects the change in z-coordinate due to speed < c_light.
:param pz: Particle longitudinal momentum.
:param p0c: Reference particle momentum in eV.
:param mc2: Particle mass in eV.
:param ds: Drift length.
:return: dz=(ds-d_particle) + ds*(beta - beta_ref)/beta_ref
"""
beta = (
(1 + pz)
* p0c.unsqueeze(-1)
/ torch.sqrt(((1 + pz) * p0c.unsqueeze(-1)) ** 2 + mc2**2)
)
beta0 = p0c / torch.sqrt(p0c**2 + mc2**2)
e_tot = torch.sqrt(p0c**2 + mc2**2)

evaluation = mc2 * (beta0.unsqueeze(-1) * pz) ** 2
dz = ds.unsqueeze(-1) * pz * (
1
- 3 * (pz * beta0.unsqueeze(-1) ** 2) / 2
+ pz**2
* beta0.unsqueeze(-1) ** 2
* (2 * beta0.unsqueeze(-1) ** 2 - (mc2 / e_tot.unsqueeze(-1)) ** 2 / 2)
) * (mc2 / e_tot.unsqueeze(-1)) ** 2 * (evaluation < 3e-7 * e_tot.unsqueeze(-1)) + (
ds.unsqueeze(-1) * (beta - beta0.unsqueeze(-1)) / beta0.unsqueeze(-1)
) * (
evaluation >= 3e-7 * e_tot.unsqueeze(-1)
)

return dz


def calculate_quadrupole_coefficients(
k1: torch.Tensor,
length: torch.Tensor,
rel_p: torch.Tensor,
eps: float = double_precision_epsilon,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Returns 2x2 transfer matrix elements aij and the coefficients to calculate the
change in z position.
NOTE: Accumulated error due to machine epsilon.
:param k1: Quadrupole strength (k1 > 0 ==> defocus).
:param length: Quadrupole length.
:param rel_p: Relative momentum P/P0.
:param eps: Machine precision epsilon, default to double precision.
:return: Tuple of transfer matrix elements and coefficients.
a11, a12, a21, a22: Transfer matrix elements.
c1, c2, c3: Second order derivatives of z such that
z = c1 * x_0^2 + c2 * x_0 * px_0 + c3 * px_0^2.
"""
# TODO: Revisit to fix accumulated error due to machine epsilon
sqrt_k = torch.sqrt(torch.absolute(k1) + eps)
sk_l = sqrt_k * length.unsqueeze(-1)

cx = torch.cos(sk_l) * (k1 <= 0) + torch.cosh(sk_l) * (k1 > 0)
sx = (torch.sin(sk_l) / (sqrt_k)) * (k1 <= 0) + (torch.sinh(sk_l) / (sqrt_k)) * (
k1 > 0
)

a11 = cx
a12 = sx / rel_p
a21 = k1 * sx * rel_p
a22 = cx

c1 = k1 * (-cx * sx + length.unsqueeze(-1)) / 4
c2 = -k1 * sx**2 / (2 * rel_p)
c3 = -(cx * sx + length.unsqueeze(-1)) / (4 * rel_p**2)

return [[a11, a12], [a21, a22]], [c1, c2, c3]
323 changes: 323 additions & 0 deletions dev/bmadx_tests/quad_test.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
71 changes: 71 additions & 0 deletions tests/test_quadrupole.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import torch

from cheetah import Drift, ParameterBeam, ParticleBeam, Quadrupole, Segment
@@ -79,6 +80,9 @@ def test_quadrupole_with_misalignments_multiple_batch_dimension():


def test_tilted_quadrupole_batch():
"""
Test that a quadrupole with a tilt behaves as expected in vectorised mode.
"""
batch_shape = torch.Size([3])
incoming = ParticleBeam.from_parameters(
num_particles=torch.tensor(1000000),
@@ -105,6 +109,10 @@ def test_tilted_quadrupole_batch():


def test_tilted_quadrupole_multiple_batch_dimension():
"""
Test that a quadrupole with a tilt behaves as expected in vectorised mode with
multiple vectorisation dimensions.
"""
batch_shape = torch.Size([3, 2])
incoming = ParticleBeam.from_parameters(
num_particles=torch.tensor(10000),
@@ -124,3 +132,66 @@ def test_tilted_quadrupole_multiple_batch_dimension():
outgoing = segment(incoming)

assert torch.allclose(outgoing.particles[0, 0], outgoing.particles[0, 1])


def test_quadrupole_bmadx_tracking():
"""
Test that the results of tracking through a quadrupole with the `"bmadx"` tracking
method match the results from Bmad-X.
"""
incoming = torch.load("tests/resources/bmadx/quadrupole_incoming_beam.pt")
quadrupole = Quadrupole(
length=torch.tensor([1.0]),
k1=torch.tensor([10.0]),
misalignment=torch.tensor([[0.01, -0.02]]),
tilt=torch.tensor([0.5]),
num_steps=10,
tracking_method="bmadx",
dtype=torch.double,
)
segment = Segment(elements=[quadrupole])

# Run tracking
outgoing = segment.track(incoming)

# Load reference result computed with Bmad-X
bmadx_out_with_cheetah_coords = torch.load(
"tests/resources/bmadx/quadrupole_bmadx_out_with_cheetah_coords.pt"
)

assert torch.allclose(
outgoing.particles, bmadx_out_with_cheetah_coords, atol=1e-7, rtol=1e-7
)


@pytest.mark.parametrize("tracking_method", ["cheetah", "bmadx"])
def test_tracking_method_vectorization(tracking_method):
"""
Test that the quadruople vectorisation works correctly with both tracking methods.
Only checks the shapes, not the physical correctness of the results.
"""
quadrupole = Quadrupole(
length=torch.tensor([[0.2, 0.25], [0.3, 0.35], [0.4, 0.45]]),
k1=torch.tensor([[4.2, 4.2], [4.3, 4.3], [4.4, 4.4]]),
misalignment=torch.zeros((3, 2, 2)),
tilt=torch.zeros((3, 2)),
tracking_method=tracking_method,
)
incoming = ParticleBeam.from_parameters(
sigma_x=torch.tensor([[1e-5, 2e-5], [2e-5, 3e-5], [3e-5, 4e-5]])
)

outgoing = quadrupole.track(incoming)

assert outgoing.mu_x.shape == (3, 2)
assert outgoing.mu_px.shape == (3, 2)
assert outgoing.mu_y.shape == (3, 2)
assert outgoing.mu_py.shape == (3, 2)
assert outgoing.sigma_x.shape == (3, 2)
assert outgoing.sigma_px.shape == (3, 2)
assert outgoing.sigma_y.shape == (3, 2)
assert outgoing.sigma_py.shape == (3, 2)
assert outgoing.sigma_tau.shape == (3, 2)
assert outgoing.sigma_p.shape == (3, 2)
assert outgoing.energy.shape == (3, 2)
assert outgoing.total_charge.shape == (3, 2)