diff --git a/examples/fodo.py b/examples/fodo.py index 189dfab..223f084 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -41,13 +41,14 @@ def main(): ) # Create line with all elements line = BeamLine( + name="fodo_cell", line=[ drift1, quad1, drift2, quad2, drift3, - ] + ], ) # Serialize to YAML yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) diff --git a/src/pals_schema/BaseElement.py b/src/pals_schema/BaseElement.py index 9b361b2..577e619 100644 --- a/src/pals_schema/BaseElement.py +++ b/src/pals_schema/BaseElement.py @@ -21,5 +21,8 @@ def model_dump(self, *args, **kwargs): name = elem_dict.pop("name", None) if name is None: raise ValueError("Element missing 'name' attribute") - data = [{name: elem_dict}] + # Return a dict {name: properties} rather than a single-item list + # This makes the serialized form a plain dict so it can be passed to + # constructors using keyword expansion (e.g., Model(**data)) + data = {name: elem_dict} return data diff --git a/src/pals_schema/BeamLine.py b/src/pals_schema/BeamLine.py index e32b125..29f035a 100644 --- a/src/pals_schema/BeamLine.py +++ b/src/pals_schema/BeamLine.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, model_validator from typing import Annotated, List, Literal, Union from pals_schema.BaseElement import BaseElement @@ -7,7 +7,7 @@ from pals_schema.QuadrupoleElement import QuadrupoleElement -class BeamLine(BaseModel): +class BeamLine(BaseElement): """A line of elements and/or other lines""" # Validate every time a new value is assigned to an attribute, @@ -29,55 +29,65 @@ class BeamLine(BaseModel): ] ] - @field_validator("line", mode="before") + @model_validator(mode="before") @classmethod - def parse_list_of_dicts(cls, value): - """This method inserts the key of the one-key dictionary into - the name attribute of the elements""" - if not isinstance(value, list): - raise TypeError("line must be a list") - - if value and isinstance(value[0], BaseModel): - # Already a list of models; nothing to do - return value - - # we expect a list of dicts or strings - elements = [] - for item_dict in value: - # an element is either a reference string to another element or a dict - if isinstance(item_dict, str): + def unpack_yaml_structure(cls, data): + # Handle the top-level one-key dict: unpack the line's name + if isinstance(data, dict) and len(data) == 1: + name, value = list(data.items())[0] + if not isinstance(value, dict): + raise TypeError( + f"Value for line key {name!r} must be a dict, but we got {value!r}" + ) + value["name"] = name + data = value + # Handle the 'line' field: unpack each element's name + if "line" not in data: + raise ValueError("'line' field is missing") + if not isinstance(data["line"], list): + raise TypeError("'line' must be a list") + new_line = [] + # Loop over all elements in the line + for item in data["line"]: + # An element can be a string that refers to another element + if isinstance(item, str): raise RuntimeError("Reference/alias elements not yet implemented") - - elif isinstance(item_dict, dict): - if not (isinstance(item_dict, dict) and len(item_dict) == 1): + # An element can be a dict + elif isinstance(item, dict): + if not (len(item) == 1): raise ValueError( - f"Each line element must be a dict with exactly one key, the name of the element, but we got: {item_dict!r}" + f"Each element must be a dict with exactly one key (the element's name), but we got {item!r}" ) - [(name, fields)] = item_dict.items() - + name, fields = list(item.items())[0] if not isinstance(fields, dict): - raise ValueError( - f"Value for element key '{name}' must be a dict (got {fields!r})" + raise TypeError( + f"Value for element key {name!r} must be a dict (the element's properties), but we got {fields!r}" ) - - # Insert the name into the fields dict - fields["name"] = name - elements.append(fields) - return elements + fields["name"] = name + new_line.append(fields) + # An element can be an instance of an existing model + elif isinstance(item, BaseElement): + # Nothing to do, keep the element as is + new_line.append(item) + else: + raise TypeError( + f"Value for element key {name!r} must be a reference string or a dict, but we got {item!r}" + ) + data["line"] = new_line + return data def model_dump(self, *args, **kwargs): """This makes sure the element name property is moved out and up to a one-key dictionary""" - # Use default dump for non-line fields + # Use base element dump first and return a dict {key: value}, where 'key' + # is the name of the line and 'value' is a dict with all other properties data = super().model_dump(*args, **kwargs) - - # Reformat 'line' field as list of single-key dicts + # Reformat 'line' field as list of element dicts new_line = [] for elem in self.line: - # Use custom dump for each line element - elem_dict = elem.model_dump(**kwargs)[0] + # Use custom dump for each line element, which now returns a dict + elem_dict = elem.model_dump(**kwargs) new_line.append(elem_dict) - - data["line"] = new_line + data[self.name]["line"] = new_line return data diff --git a/tests/test_schema.py b/tests/test_schema.py index 935078f..393bf5c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -106,7 +106,7 @@ def test_QuadrupoleElement(): def test_BeamLine(): # Create first line with one base element element1 = BaseElement(name="element1") - line1 = BeamLine(line=[element1]) + line1 = BeamLine(name="line1", line=[element1]) assert line1.line == [element1] # Extend first line with one thick element element2 = ThickElement(name="element2", length=2.0) @@ -114,7 +114,7 @@ def test_BeamLine(): assert line1.line == [element1, element2] # Create second line with one drift element element3 = DriftElement(name="element3", length=3.0) - line2 = BeamLine(line=[element3]) + line2 = BeamLine(name="line2", line=[element3]) # Extend first line with second line line1.line.extend(line2.line) assert line1.line == [element1, element2, element3] @@ -126,7 +126,7 @@ def test_yaml(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = BeamLine(line=[element1, element2]) + line = BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to YAML yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") @@ -151,7 +151,7 @@ def test_json(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = BeamLine(line=[element1, element2]) + line = BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to JSON json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) print(f"\n{json_data}")