Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions examples/fodo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import json
import yaml

from pals import MagneticMultipoleParameters
from pals import Drift
from pals import Quadrupole
Expand Down Expand Up @@ -45,34 +42,24 @@ def main():
drift3,
],
)

# Serialize to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print("Dumping YAML data...")
print(f"{yaml_data}")
# Write YAML data to file
yaml_file = "examples_fodo.yaml"
with open(yaml_file, "w") as file:
file.write(yaml_data)
yaml_file = "examples_fodo.pals.yaml"
line.to_file(yaml_file)

# Read YAML data from file
with open(yaml_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse YAML data
loaded_line = BeamLine(**yaml_data)
loaded_line = BeamLine.from_file(yaml_file)

# Validate loaded data
assert line == loaded_line

# Serialize to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print("Dumping JSON data...")
print(f"{json_data}")
# Write JSON data to file
json_file = "examples_fodo.json"
with open(json_file, "w") as file:
file.write(json_data)
json_file = "examples_fodo.pals.json"
line.to_file(json_file)

# Read JSON data from file
with open(json_file, "r") as file:
json_data = json.loads(file.read())
# Parse JSON data
loaded_line = BeamLine(**json_data)
loaded_line = BeamLine.from_file(json_file)

# Validate loaded data
assert line == loaded_line

Expand Down
82 changes: 82 additions & 0 deletions src/pals/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Public, free-standing functions for PALS."""

import os


def inspect_file_extensions(filename: str):
"""Attempt to strip two levels of file extensions to determine the schema.

filename examples: fodo.pals.yaml, fodo.pals.json, ...
"""
file_noext, extension = os.path.splitext(filename)
file_noext_noext, extension_inner = os.path.splitext(file_noext)

if extension_inner != ".pals":
raise RuntimeError(
f"inspect_file_extensions: No support for file {filename} with extension {extension}. "
f"PALS files must end in .pals.json or .pals.yaml or similar."
)

return {
"file_noext": file_noext,
"extension": extension,
"file_noext_noext": file_noext_noext,
"extension_inner": extension_inner,
}


def load_file_to_dict(filename: str) -> dict:
# Attempt to strip two levels of file extensions to determine the schema.
# Examples: fodo.pals.yaml, fodo.pals.json, ...
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
filename
).values()

# examples: fodo.pals.yaml, fodo.pals.json
with open(filename, "r") as file:
if extension == ".json":
import json

pals_data = json.loads(file.read())

elif extension == ".yaml":
import yaml

pals_data = yaml.safe_load(file)

# TODO: toml, xml

else:
raise RuntimeError(
f"load_file_to_dict: No support for PALS file {filename} with extension {extension} yet."
)

return pals_data


def store_dict_to_file(filename: str, pals_dict: dict):
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
filename
).values()

# examples: fodo.pals.yaml, fodo.pals.json
if extension == ".json":
import json

json_data = json.dumps(pals_dict, sort_keys=True, indent=2)
with open(filename, "w") as file:
file.write(json_data)

elif extension == ".yaml":
import yaml

yaml_data = yaml.dump(pals_dict, default_flow_style=False)
with open(filename, "w") as file:
file.write(yaml_data)

# TODO: toml, xml

else:
raise RuntimeError(
f"store_dict_to_file: No support for PALS file {filename} with extension {extension} yet."
)
12 changes: 12 additions & 0 deletions src/pals/kinds/BeamLine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .all_elements import get_all_elements_as_annotation
from .mixin import BaseElement
from ..functions import load_file_to_dict, store_dict_to_file


class BeamLine(BaseElement):
Expand All @@ -25,3 +26,14 @@ def model_dump(self, *args, **kwargs):
from pals.kinds.mixin.all_element_mixin import dump_element_list

return dump_element_list(self, "line", *args, **kwargs)

@staticmethod
def from_file(filename: str) -> "BeamLine":
"""Load a BeamLine from a text file"""
pals_dict = load_file_to_dict(filename)
return BeamLine(**pals_dict)

def to_file(self, filename: str):
"""Save a BeamLine to a text file"""
pals_dict = self.model_dump()
store_dict_to_file(filename, pals_dict)
2 changes: 1 addition & 1 deletion tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_Quadrupole():
assert element.ElectricMultipoleP.En2 == element_electric_multipole_En2
assert element.ElectricMultipoleP.Es2 == element_electric_multipole_Es2
assert element.ElectricMultipoleP.tilt2 == element_electric_multipole_tilt2
# Serialize the BeamLine object to YAML
# Serialize the element to YAML
yaml_data = yaml.dump(element.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")

Expand Down
54 changes: 14 additions & 40 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import os
import yaml

import pals

Expand All @@ -13,17 +11,10 @@ def test_yaml():
# Create line with both elements
line = pals.BeamLine(name="line", line=[element1, element2])
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")
# Write the YAML data to a test file
test_file = "line.yaml"
with open(test_file, "w") as file:
file.write(yaml_data)
test_file = "line.pals.yaml"
line.to_file(test_file)
# Read the YAML data from the test file
with open(test_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse the YAML data back into a BeamLine object
loaded_line = pals.BeamLine(**yaml_data)
loaded_line = pals.BeamLine.from_file(test_file)
# Remove the test file
os.remove(test_file)
# Validate loaded BeamLine object
Expand All @@ -38,17 +29,10 @@ def test_json():
# Create line with both elements
line = pals.BeamLine(name="line", line=[element1, element2])
# Serialize the BeamLine object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print(f"\n{json_data}")
# Write the JSON data to a test file
test_file = "line.json"
with open(test_file, "w") as file:
file.write(json_data)
test_file = "line.pals.json"
line.to_file(test_file)
# Read the JSON data from the test file
with open(test_file, "r") as file:
json_data = json.loads(file.read())
# Parse the JSON data back into a BeamLine object
loaded_line = pals.BeamLine(**json_data)
loaded_line = pals.BeamLine.from_file(test_file)
# Remove the test file
os.remove(test_file)
# Validate loaded BeamLine object
Expand Down Expand Up @@ -224,21 +208,16 @@ def test_comprehensive_lattice():
],
)

# Test serialization to YAML
yaml_data = yaml.dump(lattice.model_dump(), default_flow_style=False)
print(f"\nComprehensive lattice YAML:\n{yaml_data}")

# Write to temporary file
yaml_file = "comprehensive_lattice.yaml"
with open(yaml_file, "w") as file:
file.write(yaml_data)
yaml_file = "comprehensive_lattice.pals.yaml"
lattice.to_file(yaml_file)

# Read back from file
with open(yaml_file, "r") as file:
loaded_yaml_data = yaml.safe_load(file)
print(f"\nComprehensive lattice YAML:\n{file.read()}")

# Deserialize back to Python object using Pydantic model logic
loaded_lattice = pals.BeamLine(**loaded_yaml_data)
loaded_lattice = pals.BeamLine.from_file(yaml_file)

# Verify the loaded lattice has the correct structure and parameter groups
assert len(loaded_lattice.line) == 31 # Should have 31 elements
Expand Down Expand Up @@ -284,21 +263,16 @@ def test_comprehensive_lattice():
assert unionele_loaded.elements[1].kind == "Drift"
assert unionele_loaded.elements[1].length == 0.1

# Test serialization to JSON
json_data = json.dumps(lattice.model_dump(), sort_keys=True, indent=2)
print(f"\nComprehensive lattice JSON:\n{json_data}")

# Write to temporary file
json_file = "comprehensive_lattice.json"
with open(json_file, "w") as file:
file.write(json_data)
json_file = "comprehensive_lattice.pals.json"
lattice.to_file(json_file)

# Read back from file
with open(json_file, "r") as file:
loaded_json_data = json.loads(file.read())
print(f"\nComprehensive lattice JSON:\n{file.read()}")

# Deserialize back to Python object using Pydantic model logic
loaded_lattice_json = pals.BeamLine(**loaded_json_data)
loaded_lattice_json = pals.BeamLine.from_file(json_file)

# Verify the loaded lattice has the correct structure and parameter groups
assert len(loaded_lattice_json.line) == 31 # Should have 31 elements
Expand Down