diff --git a/examples/fodo.py b/examples/fodo.py index c2e6ed5..c54dba0 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -12,6 +12,7 @@ from schema.QuadrupoleElement import QuadrupoleElement from schema.Line import Line +from schema.ElementWrapper import ElementWrapper def main(): @@ -44,15 +45,15 @@ def main(): # Create line with all elements line = Line( line=[ - drift1, - quad1, - drift2, - quad2, - drift3, + ElementWrapper(element=drift1), + ElementWrapper(element=quad1), + ElementWrapper(element=drift2), + ElementWrapper(element=quad2), + ElementWrapper(element=drift3), ] ) # Serialize to YAML - yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) + yaml_data = yaml.dump(line.to_dict(), default_flow_style=False) print("Dumping YAML data...") print(f"{yaml_data}") # Write YAML data to file @@ -63,11 +64,11 @@ def main(): with open(yaml_file, "r") as file: yaml_data = yaml.safe_load(file) # Parse YAML data - loaded_line = Line(**yaml_data) + loaded_line = Line.from_dict(yaml_data) # Validate loaded data assert line == loaded_line # Serialize to JSON - json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) + json_data = json.dumps(line.to_dict(), sort_keys=True, indent=2) print("Dumping JSON data...") print(f"{json_data}") # Write JSON data to file @@ -78,7 +79,7 @@ def main(): with open(json_file, "r") as file: json_data = json.loads(file.read()) # Parse JSON data - loaded_line = Line(**json_data) + loaded_line = Line.from_dict(json_data) # Validate loaded data assert line == loaded_line diff --git a/schema/ElementWrapper.py b/schema/ElementWrapper.py new file mode 100644 index 0000000..2f55d96 --- /dev/null +++ b/schema/ElementWrapper.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field +from typing import Annotated, Literal, Union +from schema.BaseElement import BaseElement +from schema.ThickElement import ThickElement +from schema.DriftElement import DriftElement +from schema.QuadrupoleElement import QuadrupoleElement + + +class ElementWrapper(BaseModel): + """Base class for element wrappers""" + + kind: Literal["ElementWrapper"] = "ElementWrapper" + + element: Annotated[ + Union[ + BaseElement, + ThickElement, + DriftElement, + QuadrupoleElement, + ], + Field(discriminator="kind"), + ] + + def to_dict(self): + element_dict = self.element.model_dump() + kind = element_dict.pop("kind") + return {kind: element_dict} + + @classmethod + def from_dict(cls, data): + kind = next(iter(data)) + element_data = data[kind] + element_data["kind"] = kind + element_class = { + "BaseElement": BaseElement, + "ThickElement": ThickElement, + "Drift": DriftElement, + "Quadrupole": QuadrupoleElement, + }[kind] + element = element_class(**element_data) + return cls(element=element) diff --git a/schema/Line.py b/schema/Line.py index e1ad070..3116d80 100644 --- a/schema/Line.py +++ b/schema/Line.py @@ -1,10 +1,7 @@ -from pydantic import BaseModel, ConfigDict, Field -from typing import Annotated, List, Literal, Union +from pydantic import BaseModel, ConfigDict +from typing import List, Literal -from schema.BaseElement import BaseElement -from schema.ThickElement import ThickElement -from schema.DriftElement import DriftElement -from schema.QuadrupoleElement import QuadrupoleElement +from schema.ElementWrapper import ElementWrapper class Line(BaseModel): @@ -16,18 +13,15 @@ class Line(BaseModel): kind: Literal["Line"] = "Line" - line: List[ - Annotated[ - Union[ - BaseElement, - ThickElement, - DriftElement, - QuadrupoleElement, - "Line", - ], - Field(discriminator="kind"), - ] - ] + line: List[ElementWrapper] + + def to_dict(self): + return {"kind": self.kind, "line": [element.to_dict() for element in self.line]} + + @classmethod + def from_dict(cls, data): + line_elements = [ElementWrapper.from_dict(element) for element in data["line"]] + return cls(line=line_elements) # Avoid circular import issues diff --git a/tests/test_schema.py b/tests/test_schema.py index 3abdc8a..5391112 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -10,6 +10,7 @@ from schema.ThickElement import ThickElement from schema.DriftElement import DriftElement from schema.QuadrupoleElement import QuadrupoleElement +from schema.ElementWrapper import ElementWrapper from schema.Line import Line @@ -104,18 +105,23 @@ def test_QuadrupoleElement(): def test_Line(): # Create first line with one base element element1 = BaseElement(name="element1") - line1 = Line(line=[element1]) - assert line1.line == [element1] + element1_wrapper = ElementWrapper(element=element1) + line1 = Line(line=[element1_wrapper]) + assert [line.element for line in line1.line] == [element1] # Extend first line with one thick element element2 = ThickElement(name="element2", length=2.0) - line1.line.extend([element2]) - assert line1.line == [element1, element2] + element2_wrapper = ElementWrapper(element=element2) + line1.line.extend([element2_wrapper]) + assert line1.line == [element1_wrapper, element2_wrapper] + assert [line.element for line in line1.line] == [element1, element2] # Create second line with one drift element element3 = DriftElement(name="element3", length=3.0) - line2 = Line(line=[element3]) + element3_wrapper = ElementWrapper(element=element3) + line2 = Line(line=[element3_wrapper]) # Extend first line with second line line1.line.extend(line2.line) - assert line1.line == [element1, element2, element3] + assert line1.line == [element1_wrapper, element2_wrapper, element3_wrapper] + assert [line.element for line in line1.line] == [element1, element2, element3] def test_yaml(): @@ -124,9 +130,11 @@ def test_yaml(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) + line = Line( + line=[ElementWrapper(element=element1), ElementWrapper(element=element2)] + ) # Serialize the Line object to YAML - yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) + yaml_data = yaml.dump(line.to_dict(), default_flow_style=False) print(f"\n{yaml_data}") # Write the YAML data to a test file test_file = "line.yaml" @@ -136,7 +144,7 @@ def test_yaml(): with open(test_file, "r") as file: yaml_data = yaml.safe_load(file) # Parse the YAML data back into a Line object - loaded_line = Line(**yaml_data) + loaded_line = Line.from_dict(yaml_data) # Remove the test file os.remove(test_file) # Validate loaded Line object @@ -149,9 +157,11 @@ def test_json(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) + line = Line( + line=[ElementWrapper(element=element1), ElementWrapper(element=element2)] + ) # Serialize the Line object to JSON - json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) + json_data = json.dumps(line.to_dict(), sort_keys=True, indent=2) print(f"\n{json_data}") # Write the JSON data to a test file test_file = "line.json" @@ -161,7 +171,7 @@ def test_json(): with open(test_file, "r") as file: json_data = json.loads(file.read()) # Parse the JSON data back into a Line object - loaded_line = Line(**json_data) + loaded_line = Line.from_dict(json_data) # Remove the test file os.remove(test_file) # Validate loaded Line object