Skip to content

Commit 2f6a332

Browse files
authored
Implement Bmad conversion device passing (#196)
* Implement test for Bmad conversion device passing * Fix Bmad device passing * Add Bmad device passing to changelog
1 parent 3730f93 commit 2f6a332

File tree

5 files changed

+98
-30
lines changed

5 files changed

+98
-30
lines changed

.vscode/settings.json

+12-16
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
{
2-
"spellright.language": [
3-
"en_GB"
4-
],
5-
"spellright.documentTypes": [
6-
"markdown",
7-
"latex",
8-
"plaintext"
9-
],
10-
"python.linting.flake8Enabled": true,
11-
"python.linting.enabled": true,
12-
"[python]": {
13-
"editor.defaultFormatter": "ms-python.black-formatter"
14-
},
15-
"python.formatting.provider": "none",
16-
"esbonio.sphinx.confDir": ""
17-
}
2+
"spellright.language": ["en_GB"],
3+
"spellright.documentTypes": ["markdown", "latex", "plaintext"],
4+
"python.linting.flake8Enabled": true,
5+
"python.linting.enabled": true,
6+
"[python]": {
7+
"editor.defaultFormatter": "ms-python.black-formatter"
8+
},
9+
"python.formatting.provider": "none",
10+
"esbonio.sphinx.confDir": "",
11+
"python.testing.unittestEnabled": false,
12+
"python.testing.pytestEnabled": true
13+
}

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Add a new class method for `ParticleBeam` to generate a 3D uniformly distributed ellipsoidal beam (see #146) (@cr-xu, @jank324)
1313
- Add Python 3.12 support (see #161) (@jank324)
1414
- Implement space charge using Green's function in a `SpaceChargeKick` element (see #142) (@greglenerd, @RemiLehe, @ax3l, @cr-xu, @jank324)
15+
- `Segment`s can now be imported from Bmad to devices other than `torch.device("cpu")` (see #196) (@jank324)
1516

1617
### 🐛 Bug fixes
1718

cheetah/accelerator/segment.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,10 @@ def from_ocelot(
272272

273273
@classmethod
274274
def from_bmad(
275-
cls, bmad_lattice_file_path: str, environment_variables: Optional[dict] = None
275+
cls,
276+
bmad_lattice_file_path: str,
277+
environment_variables: Optional[dict] = None,
278+
device: Optional[Union[str, torch.device]] = None,
276279
) -> "Segment":
277280
"""
278281
Read a Cheetah segment from a Bmad lattice file.
@@ -285,10 +288,13 @@ def from_bmad(
285288
:param bmad_lattice_file_path: Path to the Bmad lattice file.
286289
:param environment_variables: Dictionary of environment variables to use when
287290
parsing the lattice file.
291+
:param device: Device to place the lattice elements on.
288292
:return: Cheetah `Segment` representing the Bmad lattice.
289293
"""
290294
bmad_lattice_file_path = Path(bmad_lattice_file_path)
291-
return convert_bmad_lattice(bmad_lattice_file_path, environment_variables)
295+
return convert_bmad_lattice(
296+
bmad_lattice_file_path, environment_variables, device
297+
)
292298

293299
@classmethod
294300
def from_nx_tables(cls, filepath: Union[Path, str]) -> "Element":

cheetah/converters/bmad.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
from copy import deepcopy
55
from pathlib import Path
6-
from typing import Any, Optional
6+
from typing import Any, Optional, Union
77

88
import numpy as np
99
import scipy
@@ -430,11 +430,15 @@ def validate_understood_properties(understood: list[str], properties: dict) -> N
430430
)
431431

432432

433-
def convert_element(name: str, context: dict) -> "cheetah.Element":
433+
def convert_element(
434+
name: str, context: dict, device: Optional[Union[str, torch.device]] = None
435+
) -> "cheetah.Element":
434436
"""Convert a parsed Bmad element dict to a cheetah Element.
435437
436438
:param name: Name of the (top-level) element to convert.
437439
:param context: Context dictionary parsed from Bmad lattice file(s).
440+
:param device: Device to put the element on. If `None`, the device is set to
441+
`torch.device("cpu")`.
438442
:return: Converted cheetah Element. If you are calling this function yourself
439443
as a user of Cheetah, this is most likely a `Segment`.
440444
"""
@@ -443,7 +447,8 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
443447
if isinstance(bmad_parsed, list):
444448
return cheetah.Segment(
445449
elements=[
446-
convert_element(element_name, context) for element_name in bmad_parsed
450+
convert_element(element_name, context, device)
451+
for element_name in bmad_parsed
447452
],
448453
name=name,
449454
)
@@ -466,27 +471,35 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
466471
["element_type", "alias", "type", "l"], bmad_parsed
467472
)
468473
if "l" in bmad_parsed:
469-
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
474+
return cheetah.Drift(
475+
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
476+
)
470477
else:
471478
return cheetah.Marker(name=name)
472479
elif bmad_parsed["element_type"] == "instrument":
473480
validate_understood_properties(
474481
["element_type", "alias", "type", "l"], bmad_parsed
475482
)
476483
if "l" in bmad_parsed:
477-
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
484+
return cheetah.Drift(
485+
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
486+
)
478487
else:
479488
return cheetah.Marker(name=name)
480489
elif bmad_parsed["element_type"] == "pipe":
481490
validate_understood_properties(
482491
["element_type", "alias", "type", "l", "descrip"], bmad_parsed
483492
)
484-
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
493+
return cheetah.Drift(
494+
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
495+
)
485496
elif bmad_parsed["element_type"] == "drift":
486497
validate_understood_properties(
487498
["element_type", "l", "type", "descrip"], bmad_parsed
488499
)
489-
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
500+
return cheetah.Drift(
501+
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
502+
)
490503
elif bmad_parsed["element_type"] == "hkicker":
491504
validate_understood_properties(
492505
["element_type", "type", "alias"], bmad_parsed
@@ -495,6 +508,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
495508
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
496509
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
497510
name=name,
511+
device=device,
498512
)
499513
elif bmad_parsed["element_type"] == "vkicker":
500514
validate_understood_properties(
@@ -504,6 +518,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
504518
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
505519
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
506520
name=name,
521+
device=device,
507522
)
508523
elif bmad_parsed["element_type"] == "sbend":
509524
validate_understood_properties(
@@ -539,6 +554,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
539554
else None
540555
),
541556
name=name,
557+
device=device,
542558
)
543559
elif bmad_parsed["element_type"] == "quadrupole":
544560
# TODO: Aperture for quadrupoles?
@@ -551,6 +567,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
551567
k1=torch.tensor([bmad_parsed["k1"]]),
552568
tilt=torch.tensor([bmad_parsed.get("tilt", 0.0)]),
553569
name=name,
570+
device=device,
554571
)
555572
elif bmad_parsed["element_type"] == "solenoid":
556573
validate_understood_properties(
@@ -560,6 +577,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
560577
length=torch.tensor([bmad_parsed["l"]]),
561578
k=torch.tensor([bmad_parsed["ks"]]),
562579
name=name,
580+
device=device,
563581
)
564582
elif bmad_parsed["element_type"] == "lcavity":
565583
validate_understood_properties(
@@ -584,6 +602,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
584602
),
585603
frequency=torch.tensor([bmad_parsed["rf_frequency"]]),
586604
name=name,
605+
device=device,
587606
)
588607
elif bmad_parsed["element_type"] == "rcollimator":
589608
validate_understood_properties(
@@ -595,6 +614,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
595614
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
596615
shape="rectangular",
597616
name=name,
617+
device=device,
598618
)
599619
elif bmad_parsed["element_type"] == "ecollimator":
600620
validate_understood_properties(
@@ -606,6 +626,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
606626
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
607627
shape="elliptical",
608628
name=name,
629+
device=device,
609630
)
610631
elif bmad_parsed["element_type"] == "wiggler":
611632
validate_understood_properties(
@@ -622,12 +643,16 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
622643
],
623644
bmad_parsed,
624645
)
625-
return cheetah.Undulator(length=torch.tensor([bmad_parsed["l"]]), name=name)
646+
return cheetah.Undulator(
647+
length=torch.tensor([bmad_parsed["l"]]), name=name, device=device
648+
)
626649
elif bmad_parsed["element_type"] == "patch":
627650
# TODO: Does this need to be implemented in Cheetah in a more proper way?
628651
validate_understood_properties(["element_type", "tilt"], bmad_parsed)
629652
return cheetah.Drift(
630-
length=torch.tensor([bmad_parsed.get("l", 0.0)]), name=name
653+
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
654+
name=name,
655+
device=device,
631656
)
632657
else:
633658
print(
@@ -636,14 +661,18 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
636661
)
637662
# TODO: Remove the length if by adding markers to Cheeath
638663
return cheetah.Drift(
639-
name=name, length=torch.tensor([bmad_parsed.get("l", 0.0)])
664+
name=name,
665+
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
666+
device=device,
640667
)
641668
else:
642669
raise ValueError(f"Unknown Bmad element type for {name = }")
643670

644671

645672
def convert_bmad_lattice(
646-
bmad_lattice_file_path: Path, environment_variables: Optional[dict] = None
673+
bmad_lattice_file_path: Path,
674+
environment_variables: Optional[dict] = None,
675+
device: Optional[Union[str, torch.device]] = None,
647676
) -> "cheetah.Element":
648677
"""
649678
Convert a Bmad lattice file to a Cheetah `Segment`.
@@ -656,6 +685,8 @@ def convert_bmad_lattice(
656685
:param bmad_lattice_file_path: Path to the Bmad lattice file.
657686
:param environment_variables: Dictionary of environment variables to use when
658687
parsing the lattice file.
688+
:param device: Device to use for the lattice. If `None`, the device is set to
689+
`torch.device("cpu")`.
659690
:return: Cheetah `Segment` representing the Bmad lattice.
660691
"""
661692

@@ -693,4 +724,4 @@ def convert_bmad_lattice(
693724
context = parse_lines(merged_lines)
694725

695726
# Convert the parsed lattice info to Cheetah elements
696-
return convert_element(context["__use__"], context)
727+
return convert_element(context["__use__"], context, device)

tests/test_bmad_conversion.py

+34
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import torch
23

34
import cheetah
@@ -31,3 +32,36 @@ def test_bmad_tutorial():
3132
assert converted.b.e1 == correct.b.e1
3233
assert converted.q.length == correct.q.length
3334
assert converted.q.k1 == correct.q.k1
35+
36+
37+
@pytest.mark.parametrize(
38+
"device",
39+
[
40+
torch.device("cpu"),
41+
pytest.param(
42+
torch.device("cuda"),
43+
marks=pytest.mark.skipif(
44+
not torch.cuda.is_available(), reason="CUDA not available"
45+
),
46+
),
47+
pytest.param(
48+
torch.device("mps"),
49+
marks=pytest.mark.skipif(
50+
not torch.backends.mps.is_available(), reason="MPS not available"
51+
),
52+
),
53+
],
54+
)
55+
def test_device_passing(device: torch.device):
56+
"""Test that the device is passed correctly."""
57+
file_path = "tests/resources/bmad_tutorial_lattice.bmad"
58+
59+
# Convert the lattice while passing the device
60+
converted = cheetah.Segment.from_bmad(file_path, device=device)
61+
62+
# Check that the properties of the loaded elements are on the correct device
63+
assert converted.d.length.device.type == device.type
64+
assert converted.b.length.device.type == device.type
65+
assert converted.b.e1.device.type == device.type
66+
assert converted.q.length.device.type == device.type
67+
assert converted.q.k1.device.type == device.type

0 commit comments

Comments
 (0)