3
3
import re
4
4
from copy import deepcopy
5
5
from pathlib import Path
6
- from typing import Any , Optional
6
+ from typing import Any , Optional , Union
7
7
8
8
import numpy as np
9
9
import scipy
@@ -430,11 +430,15 @@ def validate_understood_properties(understood: list[str], properties: dict) -> N
430
430
)
431
431
432
432
433
- def convert_element (name : str , context : dict ) -> "cheetah.Element" :
433
+ def convert_element (
434
+ name : str , context : dict , device : Optional [Union [str , torch .device ]] = None
435
+ ) -> "cheetah.Element" :
434
436
"""Convert a parsed Bmad element dict to a cheetah Element.
435
437
436
438
:param name: Name of the (top-level) element to convert.
437
439
:param context: Context dictionary parsed from Bmad lattice file(s).
440
+ :param device: Device to put the element on. If `None`, the device is set to
441
+ `torch.device("cpu")`.
438
442
:return: Converted cheetah Element. If you are calling this function yourself
439
443
as a user of Cheetah, this is most likely a `Segment`.
440
444
"""
@@ -443,7 +447,8 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
443
447
if isinstance (bmad_parsed , list ):
444
448
return cheetah .Segment (
445
449
elements = [
446
- convert_element (element_name , context ) for element_name in bmad_parsed
450
+ convert_element (element_name , context , device )
451
+ for element_name in bmad_parsed
447
452
],
448
453
name = name ,
449
454
)
@@ -466,27 +471,35 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
466
471
["element_type" , "alias" , "type" , "l" ], bmad_parsed
467
472
)
468
473
if "l" in bmad_parsed :
469
- return cheetah .Drift (length = torch .tensor ([bmad_parsed ["l" ]]), name = name )
474
+ return cheetah .Drift (
475
+ length = torch .tensor ([bmad_parsed ["l" ]]), name = name , device = device
476
+ )
470
477
else :
471
478
return cheetah .Marker (name = name )
472
479
elif bmad_parsed ["element_type" ] == "instrument" :
473
480
validate_understood_properties (
474
481
["element_type" , "alias" , "type" , "l" ], bmad_parsed
475
482
)
476
483
if "l" in bmad_parsed :
477
- return cheetah .Drift (length = torch .tensor ([bmad_parsed ["l" ]]), name = name )
484
+ return cheetah .Drift (
485
+ length = torch .tensor ([bmad_parsed ["l" ]]), name = name , device = device
486
+ )
478
487
else :
479
488
return cheetah .Marker (name = name )
480
489
elif bmad_parsed ["element_type" ] == "pipe" :
481
490
validate_understood_properties (
482
491
["element_type" , "alias" , "type" , "l" , "descrip" ], bmad_parsed
483
492
)
484
- return cheetah .Drift (length = torch .tensor ([bmad_parsed ["l" ]]), name = name )
493
+ return cheetah .Drift (
494
+ length = torch .tensor ([bmad_parsed ["l" ]]), name = name , device = device
495
+ )
485
496
elif bmad_parsed ["element_type" ] == "drift" :
486
497
validate_understood_properties (
487
498
["element_type" , "l" , "type" , "descrip" ], bmad_parsed
488
499
)
489
- return cheetah .Drift (length = torch .tensor ([bmad_parsed ["l" ]]), name = name )
500
+ return cheetah .Drift (
501
+ length = torch .tensor ([bmad_parsed ["l" ]]), name = name , device = device
502
+ )
490
503
elif bmad_parsed ["element_type" ] == "hkicker" :
491
504
validate_understood_properties (
492
505
["element_type" , "type" , "alias" ], bmad_parsed
@@ -495,6 +508,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
495
508
length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )]),
496
509
angle = torch .tensor ([bmad_parsed .get ("kick" , 0.0 )]),
497
510
name = name ,
511
+ device = device ,
498
512
)
499
513
elif bmad_parsed ["element_type" ] == "vkicker" :
500
514
validate_understood_properties (
@@ -504,6 +518,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
504
518
length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )]),
505
519
angle = torch .tensor ([bmad_parsed .get ("kick" , 0.0 )]),
506
520
name = name ,
521
+ device = device ,
507
522
)
508
523
elif bmad_parsed ["element_type" ] == "sbend" :
509
524
validate_understood_properties (
@@ -539,6 +554,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
539
554
else None
540
555
),
541
556
name = name ,
557
+ device = device ,
542
558
)
543
559
elif bmad_parsed ["element_type" ] == "quadrupole" :
544
560
# TODO: Aperture for quadrupoles?
@@ -551,6 +567,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
551
567
k1 = torch .tensor ([bmad_parsed ["k1" ]]),
552
568
tilt = torch .tensor ([bmad_parsed .get ("tilt" , 0.0 )]),
553
569
name = name ,
570
+ device = device ,
554
571
)
555
572
elif bmad_parsed ["element_type" ] == "solenoid" :
556
573
validate_understood_properties (
@@ -560,6 +577,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
560
577
length = torch .tensor ([bmad_parsed ["l" ]]),
561
578
k = torch .tensor ([bmad_parsed ["ks" ]]),
562
579
name = name ,
580
+ device = device ,
563
581
)
564
582
elif bmad_parsed ["element_type" ] == "lcavity" :
565
583
validate_understood_properties (
@@ -584,6 +602,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
584
602
),
585
603
frequency = torch .tensor ([bmad_parsed ["rf_frequency" ]]),
586
604
name = name ,
605
+ device = device ,
587
606
)
588
607
elif bmad_parsed ["element_type" ] == "rcollimator" :
589
608
validate_understood_properties (
@@ -595,6 +614,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
595
614
y_max = torch .tensor ([bmad_parsed .get ("y_limit" , np .inf )]),
596
615
shape = "rectangular" ,
597
616
name = name ,
617
+ device = device ,
598
618
)
599
619
elif bmad_parsed ["element_type" ] == "ecollimator" :
600
620
validate_understood_properties (
@@ -606,6 +626,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
606
626
y_max = torch .tensor ([bmad_parsed .get ("y_limit" , np .inf )]),
607
627
shape = "elliptical" ,
608
628
name = name ,
629
+ device = device ,
609
630
)
610
631
elif bmad_parsed ["element_type" ] == "wiggler" :
611
632
validate_understood_properties (
@@ -622,12 +643,16 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
622
643
],
623
644
bmad_parsed ,
624
645
)
625
- return cheetah .Undulator (length = torch .tensor ([bmad_parsed ["l" ]]), name = name )
646
+ return cheetah .Undulator (
647
+ length = torch .tensor ([bmad_parsed ["l" ]]), name = name , device = device
648
+ )
626
649
elif bmad_parsed ["element_type" ] == "patch" :
627
650
# TODO: Does this need to be implemented in Cheetah in a more proper way?
628
651
validate_understood_properties (["element_type" , "tilt" ], bmad_parsed )
629
652
return cheetah .Drift (
630
- length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )]), name = name
653
+ length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )]),
654
+ name = name ,
655
+ device = device ,
631
656
)
632
657
else :
633
658
print (
@@ -636,14 +661,18 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
636
661
)
637
662
# TODO: Remove the length if by adding markers to Cheeath
638
663
return cheetah .Drift (
639
- name = name , length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )])
664
+ name = name ,
665
+ length = torch .tensor ([bmad_parsed .get ("l" , 0.0 )]),
666
+ device = device ,
640
667
)
641
668
else :
642
669
raise ValueError (f"Unknown Bmad element type for { name = } " )
643
670
644
671
645
672
def convert_bmad_lattice (
646
- bmad_lattice_file_path : Path , environment_variables : Optional [dict ] = None
673
+ bmad_lattice_file_path : Path ,
674
+ environment_variables : Optional [dict ] = None ,
675
+ device : Optional [Union [str , torch .device ]] = None ,
647
676
) -> "cheetah.Element" :
648
677
"""
649
678
Convert a Bmad lattice file to a Cheetah `Segment`.
@@ -656,6 +685,8 @@ def convert_bmad_lattice(
656
685
:param bmad_lattice_file_path: Path to the Bmad lattice file.
657
686
:param environment_variables: Dictionary of environment variables to use when
658
687
parsing the lattice file.
688
+ :param device: Device to use for the lattice. If `None`, the device is set to
689
+ `torch.device("cpu")`.
659
690
:return: Cheetah `Segment` representing the Bmad lattice.
660
691
"""
661
692
@@ -693,4 +724,4 @@ def convert_bmad_lattice(
693
724
context = parse_lines (merged_lines )
694
725
695
726
# Convert the parsed lattice info to Cheetah elements
696
- return convert_element (context ["__use__" ], context )
727
+ return convert_element (context ["__use__" ], context , device )
0 commit comments