diff --git a/binder/environment.yml b/binder/environment.yml
index 0fad6cb..67e2a37 100644
--- a/binder/environment.yml
+++ b/binder/environment.yml
@@ -2,6 +2,7 @@ channels:
 - conda-forge
 dependencies:
 - python =3.12
+- pydantic =2.11.4
 - hatch =1.14.1
 - hatchling =1.27.0
 - httpcore =1.0.7
diff --git a/python_workflow_definition/pyproject.toml b/python_workflow_definition/pyproject.toml
index 3f12642..8a50926 100644
--- a/python_workflow_definition/pyproject.toml
+++ b/python_workflow_definition/pyproject.toml
@@ -16,10 +16,11 @@ authors = [
 ]
 license = { file = "../LICENSE" }
 dependencies = [
-    "aiida-workgraph>=0.5.1,<=0.5.2",
-    "numpy>=1.21,<2",
-    "jobflow>=0.1.18,<=0.1.19",
-    "pyiron_base>=0.11.10,<=0.11.11",
+    "aiida-workgraph>=0.5.1,<=0.5.2",
+    "numpy>=1.21,<2",
+    "jobflow>=0.1.18,<=0.1.19",
+    "pyiron_base>=0.11.10,<=0.11.11",
+    "pydantic>=2.7.0,<=2.11.4",
 ]
 
 [project.optional-dependencies]
diff --git a/python_workflow_definition/src/python_workflow_definition/aiida.py b/python_workflow_definition/src/python_workflow_definition/aiida.py
index d16be48..6a917c1 100644
--- a/python_workflow_definition/src/python_workflow_definition/aiida.py
+++ b/python_workflow_definition/src/python_workflow_definition/aiida.py
@@ -1,5 +1,4 @@
 from importlib import import_module
-import json
 import traceback
 
 from aiida import orm
@@ -7,6 +6,7 @@
 from aiida_workgraph import WorkGraph, task
 from aiida_workgraph.socket import TaskSocketNamespace
 
+from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
 from python_workflow_definition.shared import (
     convert_nodes_list_to_dict,
     update_node_names,
@@ -22,8 +22,11 @@
 
 
 def load_workflow_json(file_name: str) -> WorkGraph:
-    with open(file_name) as f:
-        data = remove_result(workflow_dict=json.load(f))
+    data = remove_result(
+        workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
+            file_name=file_name
+        )
+    )
 
     wg = WorkGraph()
     task_name_mapping = {}
@@ -136,12 +139,6 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
                 SOURCE_PORT_LABEL: None,
             }
         )
-    with open(file_name, "w") as f:
-        # json.dump({"nodes": data[], "edges": edges_new_lst}, f)
-        json.dump(
-            set_result_node(workflow_dict=update_node_names(workflow_dict=data)),
-            f,
-            indent=2,
-        )
-
-    return data
+    PythonWorkflowDefinitionWorkflow(
+        **set_result_node(workflow_dict=update_node_names(workflow_dict=data))
+    ).dump_json_file(file_name=file_name, indent=2)
diff --git a/python_workflow_definition/src/python_workflow_definition/executorlib.py b/python_workflow_definition/src/python_workflow_definition/executorlib.py
index 8d08b4c..807dfcf 100644
--- a/python_workflow_definition/src/python_workflow_definition/executorlib.py
+++ b/python_workflow_definition/src/python_workflow_definition/executorlib.py
@@ -1,9 +1,9 @@
 from concurrent.futures import Executor
 from importlib import import_module
 from inspect import isfunction
-import json
 
+from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
 from python_workflow_definition.shared import (
     get_dict,
     get_list,
@@ -38,8 +38,11 @@ def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Ex
 
 
 def load_workflow_json(file_name: str, exe: Executor):
-    with open(file_name, "r") as f:
-        content = remove_result(workflow_dict=json.load(f))
+    content = remove_result(
+        workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
+            file_name=file_name
+        )
+    )
 
     edges_new_lst = content[EDGES_LABEL]
     nodes_new_dict = {}
diff --git a/python_workflow_definition/src/python_workflow_definition/jobflow.py b/python_workflow_definition/src/python_workflow_definition/jobflow.py
index c4a14af..a0fb140 100644
--- a/python_workflow_definition/src/python_workflow_definition/jobflow.py
+++ b/python_workflow_definition/src/python_workflow_definition/jobflow.py
@@ -1,10 +1,10 @@
-import json
 from importlib import import_module
 from inspect import isfunction
 
 import numpy as np
 from jobflow import job, Flow
 
+from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
 from python_workflow_definition.shared import (
     get_dict,
     get_list,
@@ -272,8 +272,11 @@ def _get_item_from_tuple(input_obj, index, index_lst):
 
 
 def load_workflow_json(file_name: str) -> Flow:
-    with open(file_name, "r") as f:
-        content = remove_result(workflow_dict=json.load(f))
+    content = remove_result(
+        workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
+            file_name=file_name
+        )
+    )
 
     edges_new_lst = []
     for edge in content[EDGES_LABEL]:
@@ -332,13 +335,13 @@ def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
         else:
             nodes_store_lst.append({"id": k, "type": "input", "value": v})
 
-    with open(file_name, "w") as f:
-        json.dump(
-            set_result_node(
-                workflow_dict=update_node_names(
-                    workflow_dict={NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}
-                )
-            ),
-            f,
-            indent=2,
+    PythonWorkflowDefinitionWorkflow(
+        **set_result_node(
+            workflow_dict=update_node_names(
+                workflow_dict={
+                    NODES_LABEL: nodes_store_lst,
+                    EDGES_LABEL: edges_lst,
+                }
+            )
         )
+    ).dump_json_file(file_name=file_name, indent=2)
diff --git a/python_workflow_definition/src/python_workflow_definition/models.py b/python_workflow_definition/src/python_workflow_definition/models.py
new file mode 100644
index 0000000..2e196ef
--- /dev/null
+++ b/python_workflow_definition/src/python_workflow_definition/models.py
@@ -0,0 +1,253 @@
+from pathlib import Path
+from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar
+from pydantic import BaseModel, Field, field_validator, field_serializer
+from pydantic import ValidationError
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+INTERNAL_DEFAULT_HANDLE = "__result__"
+T = TypeVar("T", bound="PythonWorkflowDefinitionWorkflow")
+
+__all__ = (
+    "PythonWorkflowDefinitionInputNode",
+    "PythonWorkflowDefinitionOutputNode",
+    "PythonWorkflowDefinitionFunctionNode",
+    "PythonWorkflowDefinitionEdge",
+    "PythonWorkflowDefinitionWorkflow",
+)
+
+
+class PythonWorkflowDefinitionBaseNode(BaseModel):
+    """Base model for all node types, containing common fields."""
+
+    id: int
+    # The 'type' field is overridden in subclasses with Literal types
+    # to enable discriminated unions.
+    type: str
+
+
+class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
+    """Model for input nodes."""
+
+    type: Literal["input"]
+    name: str
+    value: Optional[Any] = None
+
+
+class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):
+    """Model for output nodes."""
+
+    type: Literal["output"]
+    name: str
+
+
+class PythonWorkflowDefinitionFunctionNode(PythonWorkflowDefinitionBaseNode):
+    """Model for function execution nodes."""
+
+    type: Literal["function"]
+    value: str  # Expected format: 'module.function'
+
+    @field_validator("value")
+    @classmethod
+    def check_value_format(cls, v: str):
+        if not v or "." not in v or v.startswith(".") or v.endswith("."):
+            msg = (
+                "FunctionNode 'value' must be a non-empty string "
+                "in 'module.function' format with at least one period."
+            )
+            raise ValueError(msg)
+        return v
+
+
+# Discriminated Union for Nodes
+PythonWorkflowDefinitionNode = Annotated[
+    Union[
+        PythonWorkflowDefinitionInputNode,
+        PythonWorkflowDefinitionOutputNode,
+        PythonWorkflowDefinitionFunctionNode,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class PythonWorkflowDefinitionEdge(BaseModel):
+    """Model for edges connecting nodes."""
+
+    target: int
+    targetPort: Optional[str] = None
+    source: int
+    sourcePort: Optional[str] = None
+
+    @field_validator("sourcePort", mode="before")
+    @classmethod
+    def handle_default_source(cls, v: Any) -> Optional[str]:
+        """
+        Transforms an incoming None/null sourcePort to INTERNAL_DEFAULT_HANDLE.
+        Runs before standard validation.
+        """
+        # Allow omitting the sourcePort: null is resolved to __result__
+        if v is None:
+            return INTERNAL_DEFAULT_HANDLE
+        elif v == INTERNAL_DEFAULT_HANDLE:
+            # Disallow explicit use of the internal reserved handle name
+            msg = (
+                f"Explicit use of reserved sourcePort '{INTERNAL_DEFAULT_HANDLE}' "
+                f"is not allowed. Use null/None for the default output."
+            )
+            raise ValueError(msg)
+        return v
+
+    @field_serializer("sourcePort")
+    def serialize_source_handle(self, v: Optional[str]) -> Optional[str]:
+        """
+        Serialization (output): converts the internal INTERNAL_DEFAULT_HANDLE
+        ("__result__") back to None.
+        """
+        if v == INTERNAL_DEFAULT_HANDLE:
+            return None  # Map "__result__" back to None for JSON output
+        return v  # Keep other handle names as they are
+
+
+class PythonWorkflowDefinitionWorkflow(BaseModel):
+    """The main workflow model."""
+
+    nodes: List[PythonWorkflowDefinitionNode]
+    edges: List[PythonWorkflowDefinitionEdge]
+
+    def dump_json(
+        self,
+        *,
+        indent: Optional[int] = 2,
+        **kwargs,
+    ) -> str:
+        """
+        Dumps the workflow model to a JSON string.
+
+        Args:
+            indent: JSON indentation level.
+            **kwargs: Additional keyword arguments passed to Pydantic's model_dump.
+
+        Returns:
+            JSON string representation of the workflow.
+        """
+        # Dump the model to a dictionary first, using mode='json' for compatible types.
+        # Extra kwargs (like custom 'exclude' rules) are passed on to model_dump.
+        workflow_dict = self.model_dump(mode="json", **kwargs)
+
+        # Dump the dictionary to a JSON string
+        try:
+            json_string = json.dumps(workflow_dict, indent=indent)
+            logger.info("Successfully dumped workflow model to JSON string.")
+            return json_string
+        except TypeError as e:
+            logger.error(
+                f"Error serializing workflow dictionary to JSON: {e}", exc_info=True
+            )
+            raise  # Re-raise after logging
+
+    def dump_json_file(
+        self,
+        file_name: Union[str, Path],
+        *,
+        indent: Optional[int] = 2,
+        **kwargs,
+    ) -> None:
+        """
+        Dumps the workflow model to a JSON file.
+
+        Args:
+            file_name: Path to the output JSON file.
+            indent: JSON indentation level.
+            **kwargs: Additional keyword arguments passed to Pydantic's model_dump.
+ """ + logger.info(f"Dumping workflow model to JSON file: {file_name}") + # Pass kwargs to dump_json, which passes them to model_dump + json_string = self.dump_json( + indent=indent, + **kwargs, + ) + try: + with open(file_name, "w", encoding="utf-8") as f: + f.write(json_string) + logger.info(f"Successfully wrote workflow model to {file_name}.") + except IOError as e: + logger.error( + f"Error writing workflow model to file {file_name}: {e}", exc_info=True + ) + raise + + @classmethod + def load_json_str(cls: Type[T], json_data: Union[str, bytes]) -> dict: + """ + Loads and validates workflow data from a JSON string or bytes. + + Args: + json_data: The JSON data as a string or bytes. + + Returns: + An instance of PwdWorkflow. + + Raises: + pydantic.ValidationError: If validation fails. + json.JSONDecodeError: If json_data is not valid JSON. + """ + logger.info("Loading workflow model from JSON data...") + try: + # Pydantic v2 method handles bytes or str directly + instance = cls.model_validate_json(json_data) + # Pydantic v1 equivalent: instance = cls.parse_raw(json_data) + logger.info( + "Successfully loaded and validated workflow model from JSON data." + ) + return instance.model_dump() + except ValidationError: # Catch validation errors specifically + logger.error("Workflow model validation failed.", exc_info=True) + raise + except json.JSONDecodeError: # Catch JSON parsing errors specifically + logger.error("Invalid JSON format encountered.", exc_info=True) + raise + except Exception as e: # Catch any other unexpected errors + logger.error( + f"An unexpected error occurred during JSON loading: {e}", exc_info=True + ) + raise + + @classmethod + def load_json_file(cls: Type[T], file_name: Union[str, Path]) -> dict: + """ + Loads and validates workflow data from a JSON file. + + Args: + file_path: The path to the JSON file. + + Returns: + An instance of PwdWorkflow. + + Raises: + FileNotFoundError: If the file is not found. + pydantic.ValidationError: If validation fails. + json.JSONDecodeError: If the file is not valid JSON. + IOError: If there are other file reading issues. 
+ """ + logger.info(f"Loading workflow model from JSON file: {file_name}") + try: + file_content = Path(file_name).read_text(encoding="utf-8") + # Delegate validation to the string loading method + return cls.load_json_str(file_content) + except FileNotFoundError: + logger.error(f"JSON file not found: {file_name}", exc_info=True) + raise + except IOError as e: + logger.error(f"Error reading JSON file {file_name}: {e}", exc_info=True) + raise diff --git a/python_workflow_definition/src/python_workflow_definition/plot.py b/python_workflow_definition/src/python_workflow_definition/plot.py index c4b166a..b3d4953 100644 --- a/python_workflow_definition/src/python_workflow_definition/plot.py +++ b/python_workflow_definition/src/python_workflow_definition/plot.py @@ -1,9 +1,8 @@ -import json - from IPython.display import SVG, display import networkx as nx +from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow from python_workflow_definition.purepython import group_edges from python_workflow_definition.shared import ( get_kwargs, @@ -16,8 +15,7 @@ def plot(file_name: str): - with open(file_name, "r") as f: - content = json.load(f) + content = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name) graph = nx.DiGraph() node_dict = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]) diff --git a/python_workflow_definition/src/python_workflow_definition/purepython.py b/python_workflow_definition/src/python_workflow_definition/purepython.py index c697b08..3fec845 100644 --- a/python_workflow_definition/src/python_workflow_definition/purepython.py +++ b/python_workflow_definition/src/python_workflow_definition/purepython.py @@ -1,8 +1,8 @@ -import json from importlib import import_module from inspect import isfunction +from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow from python_workflow_definition.shared import ( get_dict, get_list, @@ -67,8 +67,11 @@ def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict): def load_workflow_json(file_name: str): - with open(file_name, "r") as f: - content = remove_result(workflow_dict=json.load(f)) + content = remove_result( + workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file( + file_name=file_name + ) + ) edges_new_lst = content[EDGES_LABEL] nodes_new_dict = {} diff --git a/python_workflow_definition/src/python_workflow_definition/pyiron_base.py b/python_workflow_definition/src/python_workflow_definition/pyiron_base.py index 29335cd..8785553 100644 --- a/python_workflow_definition/src/python_workflow_definition/pyiron_base.py +++ b/python_workflow_definition/src/python_workflow_definition/pyiron_base.py @@ -1,12 +1,12 @@ from importlib import import_module from inspect import isfunction -import json from typing import Optional import numpy as np from pyiron_base import job, Project from pyiron_base.project.delayed import DelayedObject +from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow from python_workflow_definition.shared import ( get_kwargs, get_source_handles, @@ -230,8 +230,11 @@ def load_workflow_json(file_name: str, project: Optional[Project] = None): if project is None: project = Project(".") - with open(file_name, "r") as f: - content = remove_result(workflow_dict=json.load(f)) + content = remove_result( + workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file( + file_name=file_name + ) + ) edges_new_lst = content[EDGES_LABEL] nodes_new_dict = {} @@ -293,16 +296,13 @@ def write_workflow_json( else: 
             nodes_store_lst.append({"id": k, "type": "input", "value": v})
 
-    with open(file_name, "w") as f:
-        json.dump(
-            set_result_node(
-                workflow_dict=update_node_names(
-                    workflow_dict={
-                        NODES_LABEL: nodes_store_lst,
-                        EDGES_LABEL: edges_new_lst,
-                    }
-                )
-            ),
-            f,
-            indent=2,
+    PythonWorkflowDefinitionWorkflow(
+        **set_result_node(
+            workflow_dict=update_node_names(
+                workflow_dict={
+                    NODES_LABEL: nodes_store_lst,
+                    EDGES_LABEL: edges_new_lst,
+                }
+            )
         )
+    ).dump_json_file(file_name=file_name, indent=2)
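
For reference, a minimal round-trip sketch of the pydantic layer added in models.py. The API names (PythonWorkflowDefinitionWorkflow, dump_json_file, load_json_file) come from the patch above; the workflow content itself ("operator.neg", the port names, "workflow.json") is purely illustrative:

    from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow

    # Build a tiny, hypothetical workflow: one input, one function call, one output.
    workflow = PythonWorkflowDefinitionWorkflow(
        nodes=[
            {"id": 0, "type": "input", "name": "x", "value": 1},
            {"id": 1, "type": "function", "value": "operator.neg"},  # 'module.function'
            {"id": 2, "type": "output", "name": "result"},
        ],
        edges=[
            # sourcePort=None selects the default output; it is stored internally
            # as "__result__" and serialized back to null on output.
            {"target": 1, "targetPort": "x", "source": 0, "sourcePort": None},
            {"target": 2, "targetPort": None, "source": 1, "sourcePort": None},
        ],
    )

    # Validated write: what the write_workflow_json helpers now do internally.
    workflow.dump_json_file(file_name="workflow.json", indent=2)

    # Validated read: returns a plain dict, as consumed by the load_workflow_json helpers.
    content = PythonWorkflowDefinitionWorkflow.load_json_file(file_name="workflow.json")
    assert content["edges"][0]["sourcePort"] is None  # "__result__" mapped back to None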