Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from schema.QuadrupoleElement import QuadrupoleElement

from schema.Line import Line
from schema.ElementWrapper import ElementWrapper


def main():
Expand Down Expand Up @@ -44,15 +45,15 @@ def main():
# Create line with all elements
line = Line(
line=[
drift1,
quad1,
drift2,
quad2,
drift3,
ElementWrapper(element=drift1),
ElementWrapper(element=quad1),
ElementWrapper(element=drift2),
ElementWrapper(element=quad2),
ElementWrapper(element=drift3),
]
)
# Serialize to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
yaml_data = yaml.dump(line.to_dict(), default_flow_style=False)
print("Dumping YAML data...")
print(f"{yaml_data}")
# Write YAML data to file
Expand All @@ -63,11 +64,11 @@ def main():
with open(yaml_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse YAML data
loaded_line = Line(**yaml_data)
loaded_line = Line.from_dict(yaml_data)
# Validate loaded data
assert line == loaded_line
# Serialize to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
json_data = json.dumps(line.to_dict(), sort_keys=True, indent=2)
print("Dumping JSON data...")
print(f"{json_data}")
# Write JSON data to file
Expand All @@ -78,7 +79,7 @@ def main():
with open(json_file, "r") as file:
json_data = json.loads(file.read())
# Parse JSON data
loaded_line = Line(**json_data)
loaded_line = Line.from_dict(json_data)
# Validate loaded data
assert line == loaded_line

Expand Down
41 changes: 41 additions & 0 deletions schema/ElementWrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
from typing import Annotated, Literal, Union
from schema.BaseElement import BaseElement
from schema.ThickElement import ThickElement
from schema.DriftElement import DriftElement
from schema.QuadrupoleElement import QuadrupoleElement


class ElementWrapper(BaseModel):
"""Base class for element wrappers"""

kind: Literal["ElementWrapper"] = "ElementWrapper"

element: Annotated[
Union[
BaseElement,
ThickElement,
DriftElement,
QuadrupoleElement,
],
Field(discriminator="kind"),
]

def to_dict(self):
element_dict = self.element.model_dump()
kind = element_dict.pop("kind")
return {kind: element_dict}

@classmethod
def from_dict(cls, data):
kind = next(iter(data))
element_data = data[kind]
element_data["kind"] = kind
element_class = {
"BaseElement": BaseElement,
"ThickElement": ThickElement,
"Drift": DriftElement,
"Quadrupole": QuadrupoleElement,
}[kind]
element = element_class(**element_data)
return cls(element=element)
30 changes: 12 additions & 18 deletions schema/Line.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, ConfigDict
from typing import List, Literal

from schema.BaseElement import BaseElement
from schema.ThickElement import ThickElement
from schema.DriftElement import DriftElement
from schema.QuadrupoleElement import QuadrupoleElement
from schema.ElementWrapper import ElementWrapper


class Line(BaseModel):
Expand All @@ -16,18 +13,15 @@ class Line(BaseModel):

kind: Literal["Line"] = "Line"

line: List[
Annotated[
Union[
BaseElement,
ThickElement,
DriftElement,
QuadrupoleElement,
"Line",
],
Field(discriminator="kind"),
]
]
line: List[ElementWrapper]

def to_dict(self):
return {"kind": self.kind, "line": [element.to_dict() for element in self.line]}

@classmethod
def from_dict(cls, data):
line_elements = [ElementWrapper.from_dict(element) for element in data["line"]]
return cls(line=line_elements)


# Avoid circular import issues
Expand Down
34 changes: 22 additions & 12 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from schema.ThickElement import ThickElement
from schema.DriftElement import DriftElement
from schema.QuadrupoleElement import QuadrupoleElement
from schema.ElementWrapper import ElementWrapper

from schema.Line import Line

Expand Down Expand Up @@ -104,18 +105,23 @@ def test_QuadrupoleElement():
def test_Line():
# Create first line with one base element
element1 = BaseElement(name="element1")
line1 = Line(line=[element1])
assert line1.line == [element1]
element1_wrapper = ElementWrapper(element=element1)
line1 = Line(line=[element1_wrapper])
assert [line.element for line in line1.line] == [element1]
# Extend first line with one thick element
element2 = ThickElement(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
element2_wrapper = ElementWrapper(element=element2)
line1.line.extend([element2_wrapper])
assert line1.line == [element1_wrapper, element2_wrapper]
assert [line.element for line in line1.line] == [element1, element2]
# Create second line with one drift element
element3 = DriftElement(name="element3", length=3.0)
line2 = Line(line=[element3])
element3_wrapper = ElementWrapper(element=element3)
line2 = Line(line=[element3_wrapper])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
assert line1.line == [element1_wrapper, element2_wrapper, element3_wrapper]
assert [line.element for line in line1.line] == [element1, element2, element3]


def test_yaml():
Expand All @@ -124,9 +130,11 @@ def test_yaml():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
line = Line(
line=[ElementWrapper(element=element1), ElementWrapper(element=element2)]
)
# Serialize the Line object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
yaml_data = yaml.dump(line.to_dict(), default_flow_style=False)
print(f"\n{yaml_data}")
# Write the YAML data to a test file
test_file = "line.yaml"
Expand All @@ -136,7 +144,7 @@ def test_yaml():
with open(test_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse the YAML data back into a Line object
loaded_line = Line(**yaml_data)
loaded_line = Line.from_dict(yaml_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
Expand All @@ -149,9 +157,11 @@ def test_json():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
line = Line(
line=[ElementWrapper(element=element1), ElementWrapper(element=element2)]
)
# Serialize the Line object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
json_data = json.dumps(line.to_dict(), sort_keys=True, indent=2)
print(f"\n{json_data}")
# Write the JSON data to a test file
test_file = "line.json"
Expand All @@ -161,7 +171,7 @@ def test_json():
with open(test_file, "r") as file:
json_data = json.loads(file.read())
# Parse the JSON data back into a Line object
loaded_line = Line(**json_data)
loaded_line = Line.from_dict(json_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
Expand Down