diff --git a/examples/fodo.py b/examples/fodo.py index a90cbe3..33c9a69 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -1,6 +1,3 @@ -import json -import yaml - from pals import MagneticMultipoleParameters from pals import Drift from pals import Quadrupole @@ -45,34 +42,24 @@ def main(): drift3, ], ) + # Serialize to YAML - yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) - print("Dumping YAML data...") - print(f"{yaml_data}") - # Write YAML data to file - yaml_file = "examples_fodo.yaml" - with open(yaml_file, "w") as file: - file.write(yaml_data) + yaml_file = "examples_fodo.pals.yaml" + line.to_file(yaml_file) + # Read YAML data from file - with open(yaml_file, "r") as file: - yaml_data = yaml.safe_load(file) - # Parse YAML data - loaded_line = BeamLine(**yaml_data) + loaded_line = BeamLine.from_file(yaml_file) + # Validate loaded data assert line == loaded_line + # Serialize to JSON - json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) - print("Dumping JSON data...") - print(f"{json_data}") - # Write JSON data to file - json_file = "examples_fodo.json" - with open(json_file, "w") as file: - file.write(json_data) + json_file = "examples_fodo.pals.json" + line.to_file(json_file) + # Read JSON data from file - with open(json_file, "r") as file: - json_data = json.loads(file.read()) - # Parse JSON data - loaded_line = BeamLine(**json_data) + loaded_line = BeamLine.from_file(json_file) + # Validate loaded data assert line == loaded_line diff --git a/src/pals/functions.py b/src/pals/functions.py new file mode 100644 index 0000000..4186f6f --- /dev/null +++ b/src/pals/functions.py @@ -0,0 +1,82 @@ +"""Public, free-standing functions for PALS.""" + +import os + + +def inspect_file_extensions(filename: str): + """Attempt to strip two levels of file extensions to determine the schema. + + filename examples: fodo.pals.yaml, fodo.pals.json, ... + """ + file_noext, extension = os.path.splitext(filename) + file_noext_noext, extension_inner = os.path.splitext(file_noext) + + if extension_inner != ".pals": + raise RuntimeError( + f"inspect_file_extensions: No support for file {filename} with extension {extension}. " + f"PALS files must end in .pals.json or .pals.yaml or similar." + ) + + return { + "file_noext": file_noext, + "extension": extension, + "file_noext_noext": file_noext_noext, + "extension_inner": extension_inner, + } + + +def load_file_to_dict(filename: str) -> dict: + # Attempt to strip two levels of file extensions to determine the schema. + # Examples: fodo.pals.yaml, fodo.pals.json, ... + file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions( + filename + ).values() + + # examples: fodo.pals.yaml, fodo.pals.json + with open(filename, "r") as file: + if extension == ".json": + import json + + pals_data = json.loads(file.read()) + + elif extension == ".yaml": + import yaml + + pals_data = yaml.safe_load(file) + + # TODO: toml, xml + + else: + raise RuntimeError( + f"load_file_to_dict: No support for PALS file {filename} with extension {extension} yet." + ) + + return pals_data + + +def store_dict_to_file(filename: str, pals_dict: dict): + file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions( + filename + ).values() + + # examples: fodo.pals.yaml, fodo.pals.json + if extension == ".json": + import json + + json_data = json.dumps(pals_dict, sort_keys=True, indent=2) + with open(filename, "w") as file: + file.write(json_data) + + elif extension == ".yaml": + import yaml + + yaml_data = yaml.dump(pals_dict, default_flow_style=False) + with open(filename, "w") as file: + file.write(yaml_data) + + # TODO: toml, xml + + else: + raise RuntimeError( + f"store_dict_to_file: No support for PALS file {filename} with extension {extension} yet." + ) diff --git a/src/pals/kinds/BeamLine.py b/src/pals/kinds/BeamLine.py index 2df4fca..f7dd881 100644 --- a/src/pals/kinds/BeamLine.py +++ b/src/pals/kinds/BeamLine.py @@ -3,6 +3,7 @@ from .all_elements import get_all_elements_as_annotation from .mixin import BaseElement +from ..functions import load_file_to_dict, store_dict_to_file class BeamLine(BaseElement): @@ -25,3 +26,14 @@ def model_dump(self, *args, **kwargs): from pals.kinds.mixin.all_element_mixin import dump_element_list return dump_element_list(self, "line", *args, **kwargs) + + @staticmethod + def from_file(filename: str) -> "BeamLine": + """Load a BeamLine from a text file""" + pals_dict = load_file_to_dict(filename) + return BeamLine(**pals_dict) + + def to_file(self, filename: str): + """Save a BeamLine to a text file""" + pals_dict = self.model_dump() + store_dict_to_file(filename, pals_dict) diff --git a/tests/test_elements.py b/tests/test_elements.py index b2ee125..794b287 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -105,7 +105,7 @@ def test_Quadrupole(): assert element.ElectricMultipoleP.En2 == element_electric_multipole_En2 assert element.ElectricMultipoleP.Es2 == element_electric_multipole_Es2 assert element.ElectricMultipoleP.tilt2 == element_electric_multipole_tilt2 - # Serialize the BeamLine object to YAML + # Serialize the element to YAML yaml_data = yaml.dump(element.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ff93889..b596784 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,6 +1,4 @@ -import json import os -import yaml import pals @@ -13,17 +11,10 @@ def test_yaml(): # Create line with both elements line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to YAML - yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) - print(f"\n{yaml_data}") - # Write the YAML data to a test file - test_file = "line.yaml" - with open(test_file, "w") as file: - file.write(yaml_data) + test_file = "line.pals.yaml" + line.to_file(test_file) # Read the YAML data from the test file - with open(test_file, "r") as file: - yaml_data = yaml.safe_load(file) - # Parse the YAML data back into a BeamLine object - loaded_line = pals.BeamLine(**yaml_data) + loaded_line = pals.BeamLine.from_file(test_file) # Remove the test file os.remove(test_file) # Validate loaded BeamLine object @@ -38,17 +29,10 @@ def test_json(): # Create line with both elements line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to JSON - json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) - print(f"\n{json_data}") - # Write the JSON data to a test file - test_file = "line.json" - with open(test_file, "w") as file: - file.write(json_data) + test_file = "line.pals.json" + line.to_file(test_file) # Read the JSON data from the test file - with open(test_file, "r") as file: - json_data = json.loads(file.read()) - # Parse the JSON data back into a BeamLine object - loaded_line = pals.BeamLine(**json_data) + loaded_line = pals.BeamLine.from_file(test_file) # Remove the test file os.remove(test_file) # Validate loaded BeamLine object @@ -224,21 +208,16 @@ def test_comprehensive_lattice(): ], ) - # Test serialization to YAML - yaml_data = yaml.dump(lattice.model_dump(), default_flow_style=False) - print(f"\nComprehensive lattice YAML:\n{yaml_data}") - # Write to temporary file - yaml_file = "comprehensive_lattice.yaml" - with open(yaml_file, "w") as file: - file.write(yaml_data) + yaml_file = "comprehensive_lattice.pals.yaml" + lattice.to_file(yaml_file) # Read back from file with open(yaml_file, "r") as file: - loaded_yaml_data = yaml.safe_load(file) + print(f"\nComprehensive lattice YAML:\n{file.read()}") # Deserialize back to Python object using Pydantic model logic - loaded_lattice = pals.BeamLine(**loaded_yaml_data) + loaded_lattice = pals.BeamLine.from_file(yaml_file) # Verify the loaded lattice has the correct structure and parameter groups assert len(loaded_lattice.line) == 31 # Should have 31 elements @@ -284,21 +263,16 @@ def test_comprehensive_lattice(): assert unionele_loaded.elements[1].kind == "Drift" assert unionele_loaded.elements[1].length == 0.1 - # Test serialization to JSON - json_data = json.dumps(lattice.model_dump(), sort_keys=True, indent=2) - print(f"\nComprehensive lattice JSON:\n{json_data}") - # Write to temporary file - json_file = "comprehensive_lattice.json" - with open(json_file, "w") as file: - file.write(json_data) + json_file = "comprehensive_lattice.pals.json" + lattice.to_file(json_file) # Read back from file with open(json_file, "r") as file: - loaded_json_data = json.loads(file.read()) + print(f"\nComprehensive lattice JSON:\n{file.read()}") # Deserialize back to Python object using Pydantic model logic - loaded_lattice_json = pals.BeamLine(**loaded_json_data) + loaded_lattice_json = pals.BeamLine.from_file(json_file) # Verify the loaded lattice has the correct structure and parameter groups assert len(loaded_lattice_json.line) == 31 # Should have 31 elements