diff --git a/examples/fodo.py b/examples/fodo.py index c2e6ed5..e0ef218 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -44,11 +44,11 @@ def main(): # Create line with all elements line = Line( line=[ - drift1, - quad1, - drift2, - quad2, - drift3, + {drift1.name: drift1}, + {quad1.name: quad1}, + {drift2.name: drift2}, + {quad2.name: quad2}, + {drift3.name: drift3}, ] ) # Serialize to YAML @@ -67,7 +67,7 @@ def main(): # Validate loaded data assert line == loaded_line # Serialize to JSON - json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) + json_data = json.dumps(line.model_dump(), indent=2) print("Dumping JSON data...") print(f"{json_data}") # Write JSON data to file diff --git a/schema/BaseElement.py b/schema/BaseElement.py index 825635e..f8c6bd5 100644 --- a/schema/BaseElement.py +++ b/schema/BaseElement.py @@ -1,4 +1,8 @@ -from pydantic import BaseModel, ConfigDict +from pydantic import ( + BaseModel, + ConfigDict, + model_serializer, +) from typing import Literal, Optional @@ -14,3 +18,9 @@ class BaseElement(BaseModel): # Unique element name name: Optional[str] = None + + @model_serializer + def _serialize(self): + data = self.__dict__.copy() + data.pop("name", None) + return data diff --git a/schema/Line.py b/schema/Line.py index e1ad070..7cb3d73 100644 --- a/schema/Line.py +++ b/schema/Line.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, ConfigDict, Field -from typing import Annotated, List, Literal, Union +from typing import Annotated, Dict, List, Literal, Union from schema.BaseElement import BaseElement from schema.ThickElement import ThickElement @@ -17,15 +17,18 @@ class Line(BaseModel): kind: Literal["Line"] = "Line" line: List[ - Annotated[ - Union[ - BaseElement, - ThickElement, - DriftElement, - QuadrupoleElement, - "Line", + Dict[ + str, + Annotated[ + Union[ + BaseElement, + ThickElement, + DriftElement, + QuadrupoleElement, + "Line", + ], + Field(discriminator="kind"), ], - Field(discriminator="kind"), ] ] diff --git a/tests/test_schema.py b/tests/test_schema.py index 3abdc8a..4497ab4 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -104,18 +104,22 @@ def test_QuadrupoleElement(): def test_Line(): # Create first line with one base element element1 = BaseElement(name="element1") - line1 = Line(line=[element1]) - assert line1.line == [element1] + line1 = Line(line=[{element1.name: element1}]) + assert line1.line == [{element1.name: element1}] # Extend first line with one thick element element2 = ThickElement(name="element2", length=2.0) - line1.line.extend([element2]) - assert line1.line == [element1, element2] + line1.line.extend([{element2.name: element2}]) + assert line1.line == [{element1.name: element1}, {element2.name: element2}] # Create second line with one drift element element3 = DriftElement(name="element3", length=3.0) - line2 = Line(line=[element3]) + line2 = Line(line=[{element3.name: element3}]) # Extend first line with second line line1.line.extend(line2.line) - assert line1.line == [element1, element2, element3] + assert line1.line == [ + {element1.name: element1}, + {element2.name: element2}, + {element3.name: element3}, + ] def test_yaml(): @@ -124,7 +128,7 @@ def test_yaml(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) + line = Line(line=[{element1.name: element1}, {element2.name: element2}]) # Serialize the Line object to YAML yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") @@ -149,7 +153,7 @@ def test_json(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) + line = Line(line=[{element1.name: element1}, {element2.name: element2}]) # Serialize the Line object to JSON json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) print(f"\n{json_data}")