Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ def main():
)
# Create line with all elements
line = BeamLine(
name="fodo_cell",
line=[
drift1,
quad1,
drift2,
quad2,
drift3,
]
],
)
# Serialize to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
Expand Down
5 changes: 4 additions & 1 deletion src/pals_schema/BaseElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ def model_dump(self, *args, **kwargs):
name = elem_dict.pop("name", None)
if name is None:
raise ValueError("Element missing 'name' attribute")
data = [{name: elem_dict}]
# Return a dict {name: properties} rather than a single-item list
# This makes the serialized form a plain dict so it can be passed to
# constructors using keyword expansion (e.g., Model(**data))
data = {name: elem_dict}
return data
86 changes: 48 additions & 38 deletions src/pals_schema/BeamLine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field, model_validator
from typing import Annotated, List, Literal, Union

from pals_schema.BaseElement import BaseElement
Expand All @@ -7,7 +7,7 @@
from pals_schema.QuadrupoleElement import QuadrupoleElement


class BeamLine(BaseModel):
class BeamLine(BaseElement):
"""A line of elements and/or other lines"""

# Validate every time a new value is assigned to an attribute,
Expand All @@ -29,55 +29,65 @@ class BeamLine(BaseModel):
]
]

@field_validator("line", mode="before")
@model_validator(mode="before")
@classmethod
def parse_list_of_dicts(cls, value):
"""This method inserts the key of the one-key dictionary into
the name attribute of the elements"""
if not isinstance(value, list):
raise TypeError("line must be a list")

if value and isinstance(value[0], BaseModel):
# Already a list of models; nothing to do
return value

# we expect a list of dicts or strings
elements = []
for item_dict in value:
# an element is either a reference string to another element or a dict
if isinstance(item_dict, str):
def unpack_yaml_structure(cls, data):
# Handle the top-level one-key dict: unpack the line's name
if isinstance(data, dict) and len(data) == 1:
name, value = list(data.items())[0]
if not isinstance(value, dict):
raise TypeError(
f"Value for line key {name!r} must be a dict, but we got {value!r}"
)
value["name"] = name
data = value
# Handle the 'line' field: unpack each element's name
if "line" not in data:
raise ValueError("'line' field is missing")
if not isinstance(data["line"], list):
raise TypeError("'line' must be a list")
new_line = []
# Loop over all elements in the line
for item in data["line"]:
# An element can be a string that refers to another element
if isinstance(item, str):
raise RuntimeError("Reference/alias elements not yet implemented")

elif isinstance(item_dict, dict):
if not (isinstance(item_dict, dict) and len(item_dict) == 1):
# An element can be a dict
elif isinstance(item, dict):
if not (len(item) == 1):
raise ValueError(
f"Each line element must be a dict with exactly one key, the name of the element, but we got: {item_dict!r}"
f"Each element must be a dict with exactly one key (the element's name), but we got {item!r}"
)
[(name, fields)] = item_dict.items()

name, fields = list(item.items())[0]
if not isinstance(fields, dict):
raise ValueError(
f"Value for element key '{name}' must be a dict (got {fields!r})"
raise TypeError(
f"Value for element key {name!r} must be a dict (the element's properties), but we got {fields!r}"
)

# Insert the name into the fields dict
fields["name"] = name
elements.append(fields)
return elements
fields["name"] = name
new_line.append(fields)
# An element can be an instance of an existing model
elif isinstance(item, BaseElement):
# Nothing to do, keep the element as is
new_line.append(item)
else:
raise TypeError(
f"Value for element key {name!r} must be a reference string or a dict, but we got {item!r}"
)
data["line"] = new_line
return data

def model_dump(self, *args, **kwargs):
"""This makes sure the element name property is moved out and up to a one-key dictionary"""
# Use default dump for non-line fields
# Use base element dump first and return a dict {key: value}, where 'key'
# is the name of the line and 'value' is a dict with all other properties
data = super().model_dump(*args, **kwargs)

# Reformat 'line' field as list of single-key dicts
# Reformat 'line' field as list of element dicts
new_line = []
for elem in self.line:
# Use custom dump for each line element
elem_dict = elem.model_dump(**kwargs)[0]
# Use custom dump for each line element, which now returns a dict
elem_dict = elem.model_dump(**kwargs)
new_line.append(elem_dict)

data["line"] = new_line
data[self.name]["line"] = new_line
return data


Expand Down
8 changes: 4 additions & 4 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def test_QuadrupoleElement():
def test_BeamLine():
# Create first line with one base element
element1 = BaseElement(name="element1")
line1 = BeamLine(line=[element1])
line1 = BeamLine(name="line1", line=[element1])
assert line1.line == [element1]
# Extend first line with one thick element
element2 = ThickElement(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
# Create second line with one drift element
element3 = DriftElement(name="element3", length=3.0)
line2 = BeamLine(line=[element3])
line2 = BeamLine(name="line2", line=[element3])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
Expand All @@ -126,7 +126,7 @@ def test_yaml():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = BeamLine(line=[element1, element2])
line = BeamLine(name="line", line=[element1, element2])
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")
Expand All @@ -151,7 +151,7 @@ def test_json():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = BeamLine(line=[element1, element2])
line = BeamLine(name="line", line=[element1, element2])
# Serialize the BeamLine object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print(f"\n{json_data}")
Expand Down