Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def main():
# Create line with all elements
line = Line(
line=[
drift1,
quad1,
drift2,
quad2,
drift3,
{drift1.name: drift1},
{quad1.name: quad1},
{drift2.name: drift2},
{quad2.name: quad2},
{drift3.name: drift3},
]
)
# Serialize to YAML
Expand All @@ -67,7 +67,7 @@ def main():
# Validate loaded data
assert line == loaded_line
# Serialize to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
json_data = json.dumps(line.model_dump(), indent=2)
print("Dumping JSON data...")
print(f"{json_data}")
# Write JSON data to file
Expand Down
12 changes: 11 additions & 1 deletion schema/BaseElement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from pydantic import BaseModel, ConfigDict
from pydantic import (
BaseModel,
ConfigDict,
model_serializer,
)
from typing import Literal, Optional


Expand All @@ -14,3 +18,9 @@ class BaseElement(BaseModel):

# Unique element name
name: Optional[str] = None

@model_serializer
def _serialize(self):
data = self.__dict__.copy()
data.pop("name", None)
return data
Comment on lines +22 to +26
Copy link
Member Author

@EZoni EZoni Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ax3l

This (i.e., overloading the model serialization) is the only way I found so far to produce an output like

kind: Line
line:
- drift1:
    kind: Drift
    length: 0.25
- quad1:
    MagneticMultipoleP:
      Bn1: 1.0
    kind: Quadrupole
    length: 1.0
- drift2:
    kind: Drift
    length: 0.5
- quad2:
    MagneticMultipoleP:
      Bn1: -1.0
    kind: Quadrupole
    length: 1.0
- drift3:
    kind: Drift
    length: 0.5

instead of

kind: Line
line:
- drift1:
    kind: Drift
    length: 0.25
    name: drift1
- quad1:
    MagneticMultipoleP:
      Bn1: 1.0
    kind: Quadrupole
    length: 1.0
    name: quad1
- drift2:
    kind: Drift
    length: 0.5
    name: drift2
- quad2:
    MagneticMultipoleP:
      Bn1: -1.0
    kind: Quadrupole
    length: 1.0
    name: quad2
- drift3:
    kind: Drift
    length: 0.5
    name: drift3

with the name attributes printed out as well (hence repeated).

Do you see other ways to obtain this, other than overloading the model serialization?

I guess we would have to do something similar for the deserialization (through @model_validator(mode="before")?), although I have not been able to make that work just yet.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the custom deserialization is necessary if we expect that this

yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
loaded_line = Line(**yaml_data)

should lead to line and loaded_line being identical (which we currently assert in the tests and examples).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible alternative posted in #14 that does keep the API user-friendly :)

21 changes: 12 additions & 9 deletions schema/Line.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Annotated, List, Literal, Union
from typing import Annotated, Dict, List, Literal, Union

from schema.BaseElement import BaseElement
from schema.ThickElement import ThickElement
Expand All @@ -17,15 +17,18 @@ class Line(BaseModel):
kind: Literal["Line"] = "Line"

line: List[
Annotated[
Union[
BaseElement,
ThickElement,
DriftElement,
QuadrupoleElement,
"Line",
Dict[
str,
Annotated[
Union[
BaseElement,
ThickElement,
DriftElement,
QuadrupoleElement,
"Line",
],
Field(discriminator="kind"),
],
Field(discriminator="kind"),
]
]

Expand Down
20 changes: 12 additions & 8 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,22 @@ def test_QuadrupoleElement():
def test_Line():
# Create first line with one base element
element1 = BaseElement(name="element1")
line1 = Line(line=[element1])
assert line1.line == [element1]
line1 = Line(line=[{element1.name: element1}])
assert line1.line == [{element1.name: element1}]
# Extend first line with one thick element
element2 = ThickElement(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
line1.line.extend([{element2.name: element2}])
assert line1.line == [{element1.name: element1}, {element2.name: element2}]
# Create second line with one drift element
element3 = DriftElement(name="element3", length=3.0)
line2 = Line(line=[element3])
line2 = Line(line=[{element3.name: element3}])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
assert line1.line == [
{element1.name: element1},
{element2.name: element2},
{element3.name: element3},
]


def test_yaml():
Expand All @@ -124,7 +128,7 @@ def test_yaml():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
line = Line(line=[{element1.name: element1}, {element2.name: element2}])
# Serialize the Line object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")
Expand All @@ -149,7 +153,7 @@ def test_json():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
line = Line(line=[{element1.name: element1}, {element2.name: element2}])
# Serialize the Line object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print(f"\n{json_data}")
Expand Down
Loading