Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Bmad conversion device passing #196

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
{
"spellright.language": [
"en_GB"
],
"spellright.documentTypes": [
"markdown",
"latex",
"plaintext"
],
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"esbonio.sphinx.confDir": ""
}
"spellright.language": ["en_GB"],
"spellright.documentTypes": ["markdown", "latex", "plaintext"],
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"esbonio.sphinx.confDir": "",
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Add a new class method for `ParticleBeam` to generate a 3D uniformly distributed ellipsoidal beam (see #146) (@cr-xu, @jank324)
- Add Python 3.12 support (see #161) (@jank324)
- Implement space charge using Green's function in a `SpaceChargeKick` element (see #142) (@greglenerd, @RemiLehe, @ax3l, @cr-xu, @jank324)
- `Segment`s can now be imported from Bmad to devices other than `torch.device("cpu")` (see #196) (@jank324)

### 🐛 Bug fixes

Expand Down
10 changes: 8 additions & 2 deletions cheetah/accelerator/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,10 @@ def from_ocelot(

@classmethod
def from_bmad(
cls, bmad_lattice_file_path: str, environment_variables: Optional[dict] = None
cls,
bmad_lattice_file_path: str,
environment_variables: Optional[dict] = None,
device: Optional[Union[str, torch.device]] = None,
) -> "Segment":
"""
Read a Cheetah segment from a Bmad lattice file.
Expand All @@ -285,10 +288,13 @@ def from_bmad(
:param bmad_lattice_file_path: Path to the Bmad lattice file.
:param environment_variables: Dictionary of environment variables to use when
parsing the lattice file.
:param device: Device to place the lattice elements on.
:return: Cheetah `Segment` representing the Bmad lattice.
"""
bmad_lattice_file_path = Path(bmad_lattice_file_path)
return convert_bmad_lattice(bmad_lattice_file_path, environment_variables)
return convert_bmad_lattice(
bmad_lattice_file_path, environment_variables, device
)

@classmethod
def from_nx_tables(cls, filepath: Union[Path, str]) -> "Element":
Expand Down
55 changes: 43 additions & 12 deletions cheetah/converters/bmad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union

import numpy as np
import scipy
Expand Down Expand Up @@ -430,11 +430,15 @@ def validate_understood_properties(understood: list[str], properties: dict) -> N
)


def convert_element(name: str, context: dict) -> "cheetah.Element":
def convert_element(
name: str, context: dict, device: Optional[Union[str, torch.device]] = None
) -> "cheetah.Element":
"""Convert a parsed Bmad element dict to a cheetah Element.
:param name: Name of the (top-level) element to convert.
:param context: Context dictionary parsed from Bmad lattice file(s).
:param device: Device to put the element on. If `None`, the device is set to
`torch.device("cpu")`.
:return: Converted cheetah Element. If you are calling this function yourself
as a user of Cheetah, this is most likely a `Segment`.
"""
Expand All @@ -443,7 +447,8 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
if isinstance(bmad_parsed, list):
return cheetah.Segment(
elements=[
convert_element(element_name, context) for element_name in bmad_parsed
convert_element(element_name, context, device)
for element_name in bmad_parsed
],
name=name,
)
Expand All @@ -466,27 +471,35 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
["element_type", "alias", "type", "l"], bmad_parsed
)
if "l" in bmad_parsed:
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
return cheetah.Drift(
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
)
else:
return cheetah.Marker(name=name)
elif bmad_parsed["element_type"] == "instrument":
validate_understood_properties(
["element_type", "alias", "type", "l"], bmad_parsed
)
if "l" in bmad_parsed:
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
return cheetah.Drift(
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
)
else:
return cheetah.Marker(name=name)
elif bmad_parsed["element_type"] == "pipe":
validate_understood_properties(
["element_type", "alias", "type", "l", "descrip"], bmad_parsed
)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
return cheetah.Drift(
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
)
elif bmad_parsed["element_type"] == "drift":
validate_understood_properties(
["element_type", "l", "type", "descrip"], bmad_parsed
)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
return cheetah.Drift(
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
)
elif bmad_parsed["element_type"] == "hkicker":
validate_understood_properties(
["element_type", "type", "alias"], bmad_parsed
Expand All @@ -495,6 +508,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "vkicker":
validate_understood_properties(
Expand All @@ -504,6 +518,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "sbend":
validate_understood_properties(
Expand Down Expand Up @@ -539,6 +554,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
else None
),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "quadrupole":
# TODO: Aperture for quadrupoles?
Expand All @@ -551,6 +567,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
k1=torch.tensor([bmad_parsed["k1"]]),
tilt=torch.tensor([bmad_parsed.get("tilt", 0.0)]),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "solenoid":
validate_understood_properties(
Expand All @@ -560,6 +577,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
length=torch.tensor([bmad_parsed["l"]]),
k=torch.tensor([bmad_parsed["ks"]]),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "lcavity":
validate_understood_properties(
Expand All @@ -584,6 +602,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
),
frequency=torch.tensor([bmad_parsed["rf_frequency"]]),
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "rcollimator":
validate_understood_properties(
Expand All @@ -595,6 +614,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
shape="rectangular",
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "ecollimator":
validate_understood_properties(
Expand All @@ -606,6 +626,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
shape="elliptical",
name=name,
device=device,
)
elif bmad_parsed["element_type"] == "wiggler":
validate_understood_properties(
Expand All @@ -622,12 +643,16 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
],
bmad_parsed,
)
return cheetah.Undulator(length=torch.tensor([bmad_parsed["l"]]), name=name)
return cheetah.Undulator(
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
)
elif bmad_parsed["element_type"] == "patch":
# TODO: Does this need to be implemented in Cheetah in a more proper way?
validate_understood_properties(["element_type", "tilt"], bmad_parsed)
return cheetah.Drift(
length=torch.tensor([bmad_parsed.get("l", 0.0)]), name=name
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
name=name,
device=device,
)
else:
print(
Expand All @@ -636,14 +661,18 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
)
# TODO: Remove the length if by adding markers to Cheeath
return cheetah.Drift(
name=name, length=torch.tensor([bmad_parsed.get("l", 0.0)])
name=name,
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
device=device,
)
else:
raise ValueError(f"Unknown Bmad element type for {name = }")


def convert_bmad_lattice(
bmad_lattice_file_path: Path, environment_variables: Optional[dict] = None
bmad_lattice_file_path: Path,
environment_variables: Optional[dict] = None,
device: Optional[Union[str, torch.device]] = None,
) -> "cheetah.Element":
"""
Convert a Bmad lattice file to a Cheetah `Segment`.
Expand All @@ -656,6 +685,8 @@ def convert_bmad_lattice(
:param bmad_lattice_file_path: Path to the Bmad lattice file.
:param environment_variables: Dictionary of environment variables to use when
parsing the lattice file.
:param device: Device to use for the lattice. If `None`, the device is set to
`torch.device("cpu")`.
:return: Cheetah `Segment` representing the Bmad lattice.
"""

Expand Down Expand Up @@ -693,4 +724,4 @@ def convert_bmad_lattice(
context = parse_lines(merged_lines)

# Convert the parsed lattice info to Cheetah elements
return convert_element(context["__use__"], context)
return convert_element(context["__use__"], context, device)
34 changes: 34 additions & 0 deletions tests/test_bmad_conversion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import torch

import cheetah
Expand Down Expand Up @@ -31,3 +32,36 @@ def test_bmad_tutorial():
assert converted.b.e1 == correct.b.e1
assert converted.q.length == correct.q.length
assert converted.q.k1 == correct.q.k1


@pytest.mark.parametrize(
"device",
[
torch.device("cpu"),
pytest.param(
torch.device("cuda"),
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
pytest.param(
torch.device("mps"),
marks=pytest.mark.skipif(
not torch.backends.mps.is_available(), reason="MPS not available"
),
),
],
)
def test_device_passing(device: torch.device):
"""Test that the device is passed correctly."""
file_path = "tests/resources/bmad_tutorial_lattice.bmad"

# Convert the lattice while passing the device
converted = cheetah.Segment.from_bmad(file_path, device=device)

# Check that the properties of the loaded elements are on the correct device
assert converted.d.length.device.type == device.type
assert converted.b.length.device.type == device.type
assert converted.b.e1.device.type == device.type
assert converted.q.length.device.type == device.type
assert converted.q.k1.device.type == device.type