diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 1e02c2a..7a9e2ff 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -1,4 +1,4 @@ -name: pals +name: pals-python on: push: @@ -7,7 +7,7 @@ on: pull_request: concurrency: - group: ${{ github.ref }}-${{ github.head_ref }}-pals + group: ${{ github.ref }}-${{ github.head_ref }}-pals-python cancel-in-progress: true jobs: @@ -27,3 +27,19 @@ jobs: - name: Test run: | pytest tests -v + examples: + name: examples + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Test + run: | + python examples/fodo.py diff --git a/.gitignore b/.gitignore index eed53e5..2750157 100644 --- a/.gitignore +++ b/.gitignore @@ -1,163 +1,4 @@ -# Byte-compiled / optimized / DLL files __pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ .pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# temporary files from demo.py -line.pmad.* +.ruff_cache/ +*.swp diff --git a/README.md b/README.md index 6fd75d2..a036dbe 100644 --- a/README.md +++ b/README.md @@ -1,205 +1,73 @@ # Meet Your Python PALS -This is a Python implementation for the Particle Accelerator Lattice Standard (PALS). +This is a Python implementation for the Particle Accelerator Lattice Standard ([PALS](https://github.com/campa-consortium/pals)). -To define the PALS schema, [Pydantic](https://docs.pydantic.dev) is used to map to Python objects, perform automatic validation, and to (de)serialize data classes to/from many modern file formats. -Various modern file formats (e.g., YAML, JSON, TOML, XML, ...) are supported, which makes implementation of the schema-following creates files in any modern programming language easy (e.g., Python, Julia, LUA, C++, Javascript, ...). +To define the PALS schema, [Pydantic](https://docs.pydantic.dev) is used to map to Python objects, perform automatic validation, and serialize/deserialize data classes to/from many modern file formats. 
+Various modern file formats (e.g., YAML, JSON, TOML, XML, etc.) are supported, which makes the implementation of the schema-following files in any modern programming language easy (e.g., Python, Julia, C++, LUA, Javascript, etc.). Here, we do Python. ## Status -This project is a work-in-progress and evolves alongside the Particle Accelerator Lattice Standard (PALS). +This project is a work-in-progress and evolves alongside the Particle Accelerator Lattice Standard ([PALS](https://github.com/campa-consortium/pals)). ## Approach -This project implements the PALS schema in a file agnostic way, mirrored in data objects. -The corresponding serialized files (and optionally, also the corresponding Python objects) can be human-written, human-read, and automatically be validated. +This project implements the PALS schema in a file-agnostic way, mirrored in data objects. +The corresponding serialized files (and optionally, also the corresponding Python objects) can be human-written, human-read, and automatically validated. PALS files follow a schema and readers can error out on issues. Not every PALS implementation needs to be as detailed as this reference implementation in Python. -Nonetheless, you can use this implementation to convert between differnt file formats (see above) or to validate a file before reading it with your favorite YAML/JSON/TOML/XML/... library in your programming language of choice. - -So let's go, let us use the element descriptions we love and do not spend time anymore on parsing differences between code conventions. +Nonetheless, you can use this implementation to convert between different file formats or to validate a file before reading it with your favorite YAML/JSON/TOML/XML/... library in your programming language of choice. This will enable us to: -- exchange lattices between codes -- use common GUIs for defining lattices -- use common lattice visualization tools (2D, 3D, etc.)
+- exchange lattices between codes; +- use common GUIs for defining lattices; +- use common lattice visualization tools (2D, 3D, etc.). ### FAQ -*Why do you use Pydantic for this implementation?* +*Why use Pydantic for this implementation?* Implementing directly against a specific file format is possible, but cumbersome. -By using widely-used schema engine, we can get the "last" part, serialization and deserialization to various file formats (and converting between them, and validating them) for free. +By using a widely-used schema engine, such as [Pydantic](https://docs.pydantic.dev), we can get serialization/deserialization to/from various file formats, conversion, and validation "for free". ## Roadmap Preliminary roadmap: -1. Define the PALS schema, using Pydantic -2. Document the API well. -3. Reference implementation in Python -3.1. attract additional reference implementations in other languages. +1. Define the PALS schema, using Pydantic. +2. Document the API. +3. Reference implementation in Python. +3.1. Attract additional reference implementations in other languages. 4. Add supporting helpers, which can import existing MAD-X, Elegant, SXF files. -4.1. Try to be pretty feature complete in these importers (yeah, hard). -5. Implement readers in active community codes for beamline modeling. - Reuse the reference implementations: e.g., we will use this project for the [BLAST codes](https://blast.lbl.gov). 
- - -## Examples - -### YAML - -```yaml -line: -- ds: 1.0 - element: drift - nslice: 1 -- ds: 0.5 - element: sbend - nslice: 1 - rc: 5.0 -- ds: 0.0 - element: marker - name: midpoint -- line: - - ds: 1.0 - element: drift - nslice: 1 - - ds: 0.0 - element: marker - name: otherpoint - - ds: 0.5 - element: sbend - name: special-bend - nslice: 1 - rc: 5.0 - - ds: 0.5 - element: sbend - nslice: 1 - rc: 5.0 -``` - -### JSON - -```json -{ - "line": [ - { - "ds": 1.0, - "element": "drift", - "nslice": 1 - }, - { - "ds": 0.5, - "element": "sbend", - "nslice": 1, - "rc": 5.0 - }, - { - "ds": 0.0, - "element": "marker", - "name": "midpoint" - }, - { - "line": [ - { - "ds": 1.0, - "element": "drift", - "nslice": 1 - }, - { - "ds": 0.0, - "element": "marker", - "name": "otherpoint" - }, - { - "ds": 0.5, - "element": "sbend", - "name": "special-bend", - "nslice": 1, - "rc": 5.0 - }, - { - "ds": 0.5, - "element": "sbend", - "nslice": 1, - "rc": 5.0 - } - ] - } - ] -} -``` - -### Python Dictionary - -```py -{ - "line": [ - { - "ds": 1.0, - "element": "drift", - "nslice": 1 - }, - { - "ds": 0.5, - "element": "sbend", - "nslice": 1, - "rc": 5.0 - }, - { - "ds": 0.0, - "element": "marker", - "name": "midpoint" - }, - { - "line": [ - { - "ds": 1.0, - "element": "drift", - "nslice": 1 - }, - { - "ds": 0.0, - "element": "marker", - "name": "otherpoint" - }, - { - "ds": 0.5, - "element": "sbend", - "name": "special-bend", - "nslice": 1, - "rc": 5.0 - }, - { - "ds": 0.5, - "element": "sbend", - "nslice": 1, - "rc": 5.0 - } - ] - } - ] -} -``` - -### Python Dataclass Objects - -```py -line=[ - Drift(name=None, ds=1.0, nslice=1, element='drift'), - SBend(name=None, ds=0.5, nslice=1, element='sbend', rc=5.0), - Marker(name='midpoint', ds=0.0, element='marker'), - Line(line=[ - Drift(name=None, ds=1.0, nslice=1, element='drift'), - Marker(name='otherpoint', ds=0.0, element='marker'), - SBend(name='special-bend', ds=0.5, nslice=1, element='sbend', rc=5.0), - SBend(name=None, ds=0.5, nslice=1, 
element='sbend', rc=5.0) - ]) -] -``` +4.1. Try to be as feature complete as possible in these importers. +5. Reuse the reference implementations and implement readers in community codes for beamline modeling (e.g., the [BLAST codes](https://blast.lbl.gov)). + + +## How to run the tests and examples locally + +In order to run the tests and examples locally, please follow these steps: + +1. Create a conda environment from the `environment.yml` file: + ```bash + conda env create -f environment.yml + ``` +2. Activate the conda environment: + ```bash + conda activate pals-python + ``` + Please double check the environment name in the `environment.yml` file. +3. Run the tests locally: + ```bash + pytest tests -v + ``` + The command line option `-v` increases the verbosity of the output. + You can also use the command line option `-s` to display any test output directly in the console (useful for debugging). + Please refer to [pytest's documentation](https://docs.pytest.org/en/stable/) for further details on the available command line options and/or run `pytest --help`. +4. 
Run the examples locally (e.g., `fodo.py`): + ```bash + python examples/fodo.py + ``` diff --git a/demo.py b/demo.py deleted file mode 100755 index 4e1f82c..0000000 --- a/demo.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -# -# -*- coding: utf-8 -*- - -import pydantic - -import json -import toml # Python 3.11 tomllib -import yaml - -from schema import Line -from schema.elements import Drift, SBend, Marker - -print(pydantic.__version__) - - -# define -line = Line( - line=[ - Drift(ds=1.0), - SBend(ds=0.5, rc=5), - Marker(name="midpoint"), - ] -) -line.line.extend( - [ - Line( - line=[ - Drift(ds=1.0), - Marker(name="otherpoint"), - SBend(ds=0.5, rc=5, name="special-bend"), - SBend(ds=0.5, rc=5), - ] - ) - ] -) - -# doc strings -print(SBend.__doc__) -# help(SBend) # more to explore for simplified output, like pydantic-autodoc in Sphinx - -# export -print(f"Python:\n{line}") -model = line.model_dump(mode="json", exclude_none=True) -print(f"JSON model:\n{model}") -model_py = line.model_dump(mode="python", exclude_none=True) -print(f"Python model:\n{model_py}") - - -model_json = json.dumps(line.model_dump(exclude_none=True), sort_keys=True, indent=4) -print(model_json) - -with open("line.pmad.json", "w") as out_file: - out_file.write(model_json) - - -# import -with open("line.pmad.json", "r") as in_file: - read_json_dict = json.loads(in_file.read()) - -print(read_json_dict) - -# validate -read_json_model = Line(**read_json_dict) -print(read_json_model) - -# ensures correctness in construction, read-from-file -# AND in interactive use -try: - Drift(ds=-1.0) # fails with: Input should be greater than 0 -except pydantic.ValidationError as e: - print(e) - -try: - d = Drift(ds=1.0) - d.ds = -1.0 # fails with: Input should be greater than 0 -except pydantic.ValidationError as e: - print(e) -print(d) - -# json schema file for validation outside of pydantic -with open("line.pmad.json.schema", "w") as out_file: - out_file.write(json.dumps(line.model_json_schema(), 
sort_keys=True, indent=4)) - - -# yaml! -# export -with open("line.pmad.yaml", "w") as out_file: - yaml.dump(line.model_dump(exclude_none=True), out_file) - - -# import -def read_yaml(file_path: str) -> dict: - with open(file_path, "r") as stream: - config = yaml.safe_load(stream) - - return Line(**config).model_dump() - - -read_yaml_dict = read_yaml("line.pmad.yaml") - -read_yaml_model = Line(**read_yaml_dict) -print(read_yaml_model) - - -# toml! (looks surprisingly ugly -.-) -# export -with open("line.pmad.toml", "w") as out_file: - toml.dump(line.model_dump(exclude_none=True), out_file) - -# import -with open("line.pmad.toml", "r") as in_file: - read_toml_dict = toml.load(in_file) - -read_toml_model = Line(**read_toml_dict) -print(read_toml_model) - -# XML: https://github.com/martinblech/xmltodict diff --git a/environment.yml b/environment.yml index 0fbe4a2..cb53eb5 100644 --- a/environment.yml +++ b/environment.yml @@ -7,6 +7,7 @@ channels: dependencies: - pre-commit - pydantic + - pytest - python - pyyaml - toml diff --git a/examples/fodo.py b/examples/fodo.py new file mode 100644 index 0000000..c2e6ed5 --- /dev/null +++ b/examples/fodo.py @@ -0,0 +1,87 @@ +import json +import os +import sys +import yaml + +# Add the parent directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from schema.MagneticMultipoleParameters import MagneticMultipoleParameters + +from schema.DriftElement import DriftElement +from schema.QuadrupoleElement import QuadrupoleElement + +from schema.Line import Line + + +def main(): + drift1 = DriftElement( + name="drift1", + length=0.25, + ) + quad1 = QuadrupoleElement( + name="quad1", + length=1.0, + MagneticMultipoleP=MagneticMultipoleParameters( + Bn1=1.0, + ), + ) + drift2 = DriftElement( + name="drift2", + length=0.5, + ) + quad2 = QuadrupoleElement( + name="quad2", + length=1.0, + MagneticMultipoleP=MagneticMultipoleParameters( + Bn1=-1.0, + ), + ) + drift3 = DriftElement( + 
name="drift3", + length=0.5, + ) + # Create line with all elements + line = Line( + line=[ + drift1, + quad1, + drift2, + quad2, + drift3, + ] + ) + # Serialize to YAML + yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) + print("Dumping YAML data...") + print(f"{yaml_data}") + # Write YAML data to file + yaml_file = "examples_fodo.yaml" + with open(yaml_file, "w") as file: + file.write(yaml_data) + # Read YAML data from file + with open(yaml_file, "r") as file: + yaml_data = yaml.safe_load(file) + # Parse YAML data + loaded_line = Line(**yaml_data) + # Validate loaded data + assert line == loaded_line + # Serialize to JSON + json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) + print("Dumping JSON data...") + print(f"{json_data}") + # Write JSON data to file + json_file = "examples_fodo.json" + with open(json_file, "w") as file: + file.write(json_data) + # Read JSON data from file + with open(json_file, "r") as file: + json_data = json.loads(file.read()) + # Parse JSON data + loaded_line = Line(**json_data) + # Validate loaded data + assert line == loaded_line + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b5d87fb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +pydantic +pytest +pyyaml +toml diff --git a/schema/BaseElement.py b/schema/BaseElement.py new file mode 100644 index 0000000..825635e --- /dev/null +++ b/schema/BaseElement.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel, ConfigDict +from typing import Literal, Optional + + +class BaseElement(BaseModel): + """A custom base element defining common properties""" + + # Discriminator field + kind: Literal["BaseElement"] = "BaseElement" + + # Validate every time a new value is assigned to an attribute, + # not only when an instance of BaseElement is created + model_config = ConfigDict(validate_assignment=True) + + # Unique element name + name: Optional[str] = None diff --git 
a/schema/DriftElement.py b/schema/DriftElement.py new file mode 100644 index 0000000..ce1bc01 --- /dev/null +++ b/schema/DriftElement.py @@ -0,0 +1,10 @@ +from typing import Literal + +from .ThickElement import ThickElement + + +class DriftElement(ThickElement): + """A field free region""" + + # Discriminator field + kind: Literal["Drift"] = "Drift" diff --git a/schema/Line.py b/schema/Line.py index c94970f..e1ad070 100644 --- a/schema/Line.py +++ b/schema/Line.py @@ -1,24 +1,34 @@ -from pydantic import BaseModel, Field -from typing import List, Union +from pydantic import BaseModel, ConfigDict, Field +from typing import Annotated, List, Literal, Union -from .elements import Drift, SBend, Marker +from schema.BaseElement import BaseElement +from schema.ThickElement import ThickElement +from schema.DriftElement import DriftElement +from schema.QuadrupoleElement import QuadrupoleElement class Line(BaseModel): """A line of elements and/or other lines""" - line: List[Union[Drift, SBend, Marker, "Line"]] = Field( - ..., discriminator="element" - ) - """A list of elements and/or other lines""" + # Validate every time a new value is assigned to an attribute, + # not only when an instance of Line is created + model_config = ConfigDict(validate_assignment=True) - # Hints for pure Python usage - class Config: - validate_assignment = True + kind: Literal["Line"] = "Line" + line: List[ + Annotated[ + Union[ + BaseElement, + ThickElement, + DriftElement, + QuadrupoleElement, + "Line", + ], + Field(discriminator="kind"), + ] + ] -# Hints for pure Python usage -Line.update_forward_refs() -# TODO / Ideas -# - Validate the Element.name is, if set, unique in a Line (including nested lines). 
+# Avoid circular import issues +Line.model_rebuild() diff --git a/schema/MagneticMultipoleParameters.py b/schema/MagneticMultipoleParameters.py new file mode 100644 index 0000000..a1b7dd0 --- /dev/null +++ b/schema/MagneticMultipoleParameters.py @@ -0,0 +1,62 @@ +from pydantic import BaseModel, ConfigDict, model_validator +from typing import Any, Dict + + +class MagneticMultipoleParameters(BaseModel): + """Magnetic multipole parameters""" + + # Allow arbitrary fields + model_config = ConfigDict(extra="allow") + + # Custom validation of magnetic multipole order + def _validate_order(key_num, msg): + if key_num.isdigit(): + if key_num.startswith("0") and key_num != "0": + raise ValueError(msg) + else: + raise ValueError(msg) + + # Custom validation to be applied before standard validation + @model_validator(mode="before") + def validate(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # loop over all attributes + for key in values: + # validate tilt parameters 'tiltN' + if key.startswith("tilt"): + key_num = key[4:] + msg = " ".join( + [ + f"Invalid tilt parameter: '{key}'.", + "Tilt parameter must be of the form 'tiltN', where 'N' is an integer.", + ] + ) + cls._validate_order(key_num, msg) + # validate normal component parameters 'BnN' + elif key.startswith("Bn"): + key_num = key[2:] + msg = " ".join( + [ + f"Invalid normal component parameter: '{key}'.", + "Normal component parameter must be of the form 'BnN', where 'N' is an integer.", + ] + ) + cls._validate_order(key_num, msg) + # validate skew component parameters 'BsN' + elif key.startswith("Bs"): + key_num = key[2:] + msg = " ".join( + [ + f"Invalid skew component parameter: '{key}'.", + "Skew component parameter must be of the form 'BsN', where 'N' is an integer.", + ] + ) + cls._validate_order(key_num, msg) + else: + msg = " ".join( + [ + f"Invalid magnetic multipole parameter: '{key}'.", + "Magnetic multipole parameters must be of the form 'tiltN', 'BnN', or 'BsN', where 'N' is an integer.", + ] + ) + 
raise ValueError(msg) + return values diff --git a/schema/QuadrupoleElement.py b/schema/QuadrupoleElement.py new file mode 100644 index 0000000..f778a3d --- /dev/null +++ b/schema/QuadrupoleElement.py @@ -0,0 +1,14 @@ +from typing import Literal + +from .ThickElement import ThickElement +from .MagneticMultipoleParameters import MagneticMultipoleParameters + + +class QuadrupoleElement(ThickElement): + """A quadrupole element""" + + # Discriminator field + kind: Literal["Quadrupole"] = "Quadrupole" + + # Magnetic multipole parameters + MagneticMultipoleP: MagneticMultipoleParameters diff --git a/schema/ThickElement.py b/schema/ThickElement.py new file mode 100644 index 0000000..9461041 --- /dev/null +++ b/schema/ThickElement.py @@ -0,0 +1,14 @@ +from typing import Annotated, Literal +from annotated_types import Gt + +from .BaseElement import BaseElement + + +class ThickElement(BaseElement): + """A thick base element with finite segment length""" + + # Discriminator field + kind: Literal["ThickElement"] = "ThickElement" + + # Segment length in meters (m) + length: Annotated[float, Gt(0)] diff --git a/schema/__init__.py b/schema/__init__.py index d97886b..e69de29 100644 --- a/schema/__init__.py +++ b/schema/__init__.py @@ -1 +0,0 @@ -from .Line import Line diff --git a/schema/elements/Drift.py b/schema/elements/Drift.py deleted file mode 100644 index f82b99c..0000000 --- a/schema/elements/Drift.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base import Thick - -from typing import Literal - - -class Drift(Thick): - """A drift element""" - - element: Literal["drift"] = "drift" - """The element type""" diff --git a/schema/elements/Marker.py b/schema/elements/Marker.py deleted file mode 100644 index 6bb28eb..0000000 --- a/schema/elements/Marker.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import Thin - -from typing import Literal - - -class Marker(Thin): - """A marker of a position in a line""" - - element: Literal["marker"] = "marker" - """The element type""" - - name: str - 
""""A unique name for the element when placed in the line""" diff --git a/schema/elements/SBend.py b/schema/elements/SBend.py deleted file mode 100644 index a5a4e8a..0000000 --- a/schema/elements/SBend.py +++ /dev/null @@ -1,14 +0,0 @@ -from .base import Thick - -from annotated_types import Gt -from typing import Annotated, Literal - - -class SBend(Thick): - """An ideal sector bend.""" - - element: Literal["sbend"] = "sbend" - """The element type""" - - rc: Annotated[float, Gt(0)] - """Radius of curvature in m""" diff --git a/schema/elements/__init__.py b/schema/elements/__init__.py deleted file mode 100644 index 2788c90..0000000 --- a/schema/elements/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .Drift import Drift -from .SBend import SBend -from .Marker import Marker diff --git a/schema/elements/base/Element.py b/schema/elements/base/Element.py deleted file mode 100644 index 6247817..0000000 --- a/schema/elements/base/Element.py +++ /dev/null @@ -1,13 +0,0 @@ -from pydantic import BaseModel -from typing import Optional - - -class Element(BaseModel): - """A mix-in model for elements, defining common properties""" - - name: Optional[str] = None - """A unique name for the element when placed in the line""" - - # Hints for pure Python usage - class Config: - validate_assignment = True diff --git a/schema/elements/base/Thick.py b/schema/elements/base/Thick.py deleted file mode 100644 index 39b2775..0000000 --- a/schema/elements/base/Thick.py +++ /dev/null @@ -1,14 +0,0 @@ -from annotated_types import Gt -from typing import Annotated - -from .Element import Element - - -class Thick(Element): - """A mix-in model for elements with finite segment length""" - - ds: Annotated[float, Gt(0)] - """Segment length in m""" - - nslice: int = 1 - """Number of slices through the segment (might be numerics and not phyics, thus might be removed)""" diff --git a/schema/elements/base/Thin.py b/schema/elements/base/Thin.py deleted file mode 100644 index 1d5273f..0000000 --- 
a/schema/elements/base/Thin.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Literal - -from .Element import Element - - -class Thin(Element): - """A mix-in model for elements with finite segment length""" - - element: Literal["ds"] = 0.0 - """Segment length in m (thin elements are always zero)""" diff --git a/schema/elements/base/__init__.py b/schema/elements/base/__init__.py deleted file mode 100644 index 339b534..0000000 --- a/schema/elements/base/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .Element import Element -from .Thin import Thin -from .Thick import Thick diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..3abdc8a --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,168 @@ +import os +from pydantic import ValidationError + +import json +import yaml + +from schema.MagneticMultipoleParameters import MagneticMultipoleParameters + +from schema.BaseElement import BaseElement +from schema.ThickElement import ThickElement +from schema.DriftElement import DriftElement +from schema.QuadrupoleElement import QuadrupoleElement + +from schema.Line import Line + + +def test_BaseElement(): + # Create one base element with custom name + element_name = "base_element" + element = BaseElement(name=element_name) + assert element.name == element_name + + +def test_ThickElement(): + # Create one thick element with custom name and length + element_name = "thick_element" + element_length = 1.0 + element = ThickElement( + name=element_name, + length=element_length, + ) + assert element.name == element_name + assert element.length == element_length + # Try to assign negative length and + # detect validation error without breaking pytest + element_length = -1.0 + passed = True + try: + element.length = element_length + except ValidationError as e: + print(e) + passed = False + assert not passed + + +def test_DriftElement(): + # Create one 
drift element with custom name and length + element_name = "drift_element" + element_length = 1.0 + element = DriftElement( + name=element_name, + length=element_length, + ) + assert element.name == element_name + assert element.length == element_length + # Try to assign negative length and + # detect validation error without breaking pytest + element_length = -1.0 + passed = True + try: + element.length = element_length + except ValidationError as e: + print(e) + passed = False + assert not passed + + +def test_QuadrupoleElement(): + # Create one quadrupole element with custom name and length + element_name = "quadrupole_element" + element_length = 1.0 + element_magnetic_multipole_Bn1 = 1.1 + element_magnetic_multipole_Bn2 = 1.2 + element_magnetic_multipole_Bs1 = 2.1 + element_magnetic_multipole_Bs2 = 2.2 + element_magnetic_multipole_tilt1 = 3.1 + element_magnetic_multipole_tilt2 = 3.2 + element_magnetic_multipole = MagneticMultipoleParameters( + Bn1=element_magnetic_multipole_Bn1, + Bs1=element_magnetic_multipole_Bs1, + tilt1=element_magnetic_multipole_tilt1, + Bn2=element_magnetic_multipole_Bn2, + Bs2=element_magnetic_multipole_Bs2, + tilt2=element_magnetic_multipole_tilt2, + ) + element = QuadrupoleElement( + name=element_name, + length=element_length, + MagneticMultipoleP=element_magnetic_multipole, + ) + assert element.name == element_name + assert element.length == element_length + assert element.MagneticMultipoleP.Bn1 == element_magnetic_multipole_Bn1 + assert element.MagneticMultipoleP.Bs1 == element_magnetic_multipole_Bs1 + assert element.MagneticMultipoleP.tilt1 == element_magnetic_multipole_tilt1 + assert element.MagneticMultipoleP.Bn2 == element_magnetic_multipole_Bn2 + assert element.MagneticMultipoleP.Bs2 == element_magnetic_multipole_Bs2 + assert element.MagneticMultipoleP.tilt2 == element_magnetic_multipole_tilt2 + # Serialize the Line object to YAML + yaml_data = yaml.dump(element.model_dump(), default_flow_style=False) + print(f"\n{yaml_data}") + 
+def test_Line(): + # Create first line with one base element + element1 = BaseElement(name="element1") + line1 = Line(line=[element1]) + assert line1.line == [element1] + # Extend first line with one thick element + element2 = ThickElement(name="element2", length=2.0) + line1.line.extend([element2]) + assert line1.line == [element1, element2] + # Create second line with one drift element + element3 = DriftElement(name="element3", length=3.0) + line2 = Line(line=[element3]) + # Extend first line with second line + line1.line.extend(line2.line) + assert line1.line == [element1, element2, element3] + + +def test_yaml(): + # Create one base element + element1 = BaseElement(name="element1") + # Create one thick element + element2 = ThickElement(name="element2", length=2.0) + # Create line with both elements + line = Line(line=[element1, element2]) + # Serialize the Line object to YAML + yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) + print(f"\n{yaml_data}") + # Write the YAML data to a test file + test_file = "line.yaml" + with open(test_file, "w") as file: + file.write(yaml_data) + # Read the YAML data from the test file + with open(test_file, "r") as file: + yaml_data = yaml.safe_load(file) + # Parse the YAML data back into a Line object + loaded_line = Line(**yaml_data) + # Remove the test file + os.remove(test_file) + # Validate loaded Line object + assert line == loaded_line + + +def test_json(): + # Create one base element + element1 = BaseElement(name="element1") + # Create one thick element + element2 = ThickElement(name="element2", length=2.0) + # Create line with both elements + line = Line(line=[element1, element2]) + # Serialize the Line object to JSON + json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) + print(f"\n{json_data}") + # Write the JSON data to a test file + test_file = "line.json" + with open(test_file, "w") as file: + file.write(json_data) + # Read the JSON data from the test file + with open(test_file, 
"r") as file: + json_data = json.loads(file.read()) + # Parse the JSON data back into a Line object + loaded_line = Line(**json_data) + # Remove the test file + os.remove(test_file) + # Validate loaded Line object + assert line == loaded_line