Skip to content

Commit e0c726f

Browse files
authored
Add pydantic validation (#97)
* Add pydantic validation * remove unused import * black formatting * return dict rather than pydantic model * different dictionary conversion * nested conversion * fix dictionary conversion * use model dump again
1 parent cd022f5 commit e0c726f

File tree

9 files changed

+312
-53
lines changed

9 files changed

+312
-53
lines changed

binder/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ channels:
22
- conda-forge
33
dependencies:
44
- python =3.12
5+
- pydantic =2.11.4
56
- hatch =1.14.1
67
- hatchling =1.27.0
78
- httpcore =1.0.7

python_workflow_definition/pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ authors = [
1616
]
1717
license = { file = "../LICENSE" }
1818
dependencies = [
19-
"aiida-workgraph>=0.5.1,<=0.5.2",
20-
"numpy>=1.21,<2",
21-
"jobflow>=0.1.18,<=0.1.19",
22-
"pyiron_base>=0.11.10,<=0.11.11",
19+
"aiida-workgraph>=0.5.1,<=0.5.2",
20+
"numpy>=1.21,<2",
21+
"jobflow>=0.1.18,<=0.1.19",
22+
"pyiron_base>=0.11.10,<=0.11.11",
23+
"pydantic>=2.7.0,<=2.11.4",
2324
]
2425

2526
[project.optional-dependencies]

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from importlib import import_module
2-
import json
32
import traceback
43

54
from aiida import orm
65
from aiida_pythonjob.data.serializer import general_serializer
76
from aiida_workgraph import WorkGraph, task
87
from aiida_workgraph.socket import TaskSocketNamespace
98

9+
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
1010
from python_workflow_definition.shared import (
1111
convert_nodes_list_to_dict,
1212
update_node_names,
@@ -22,8 +22,11 @@
2222

2323

2424
def load_workflow_json(file_name: str) -> WorkGraph:
25-
with open(file_name) as f:
26-
data = remove_result(workflow_dict=json.load(f))
25+
data = remove_result(
26+
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
27+
file_name=file_name
28+
)
29+
)
2730

2831
wg = WorkGraph()
2932
task_name_mapping = {}
@@ -136,12 +139,6 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
136139
SOURCE_PORT_LABEL: None,
137140
}
138141
)
139-
with open(file_name, "w") as f:
140-
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)
141-
json.dump(
142-
set_result_node(workflow_dict=update_node_names(workflow_dict=data)),
143-
f,
144-
indent=2,
145-
)
146-
147-
return data
142+
PythonWorkflowDefinitionWorkflow(
143+
**set_result_node(workflow_dict=update_node_names(workflow_dict=data))
144+
).dump_json_file(file_name=file_name, indent=2)

python_workflow_definition/src/python_workflow_definition/executorlib.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from concurrent.futures import Executor
22
from importlib import import_module
33
from inspect import isfunction
4-
import json
54

65

6+
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
77
from python_workflow_definition.shared import (
88
get_dict,
99
get_list,
@@ -38,8 +38,11 @@ def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Ex
3838

3939

4040
def load_workflow_json(file_name: str, exe: Executor):
41-
with open(file_name, "r") as f:
42-
content = remove_result(workflow_dict=json.load(f))
41+
content = remove_result(
42+
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
43+
file_name=file_name
44+
)
45+
)
4346

4447
edges_new_lst = content[EDGES_LABEL]
4548
nodes_new_dict = {}

python_workflow_definition/src/python_workflow_definition/jobflow.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import json
21
from importlib import import_module
32
from inspect import isfunction
43

54
import numpy as np
65
from jobflow import job, Flow
76

7+
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
88
from python_workflow_definition.shared import (
99
get_dict,
1010
get_list,
@@ -272,8 +272,11 @@ def _get_item_from_tuple(input_obj, index, index_lst):
272272

273273

274274
def load_workflow_json(file_name: str) -> Flow:
275-
with open(file_name, "r") as f:
276-
content = remove_result(workflow_dict=json.load(f))
275+
content = remove_result(
276+
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
277+
file_name=file_name
278+
)
279+
)
277280

278281
edges_new_lst = []
279282
for edge in content[EDGES_LABEL]:
@@ -332,13 +335,13 @@ def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
332335
else:
333336
nodes_store_lst.append({"id": k, "type": "input", "value": v})
334337

335-
with open(file_name, "w") as f:
336-
json.dump(
337-
set_result_node(
338-
workflow_dict=update_node_names(
339-
workflow_dict={NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}
340-
)
341-
),
342-
f,
343-
indent=2,
338+
PythonWorkflowDefinitionWorkflow(
339+
**set_result_node(
340+
workflow_dict=update_node_names(
341+
workflow_dict={
342+
NODES_LABEL: nodes_store_lst,
343+
EDGES_LABEL: edges_lst,
344+
}
345+
)
344346
)
347+
).dump_json_file(file_name=file_name, indent=2)
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
from pathlib import Path
2+
from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar
3+
from pydantic import BaseModel, Field, field_validator, field_serializer
4+
from pydantic import ValidationError
5+
import json
6+
import logging
7+
8+
logger = logging.getLogger(__name__)

# Sentinel used internally for an edge's default (unnamed) source port.
# It is mapped back to null/None when the workflow is serialized to JSON.
INTERNAL_DEFAULT_HANDLE = "__result__"

# Type variable so classmethod constructors type-check against subclasses.
T = TypeVar("T", bound="PythonWorkflowDefinitionWorkflow")

# Explicit public API of this module.
__all__ = (
    "PythonWorkflowDefinitionInputNode",
    "PythonWorkflowDefinitionOutputNode",
    "PythonWorkflowDefinitionFunctionNode",
    "PythonWorkflowDefinitionEdge",
    "PythonWorkflowDefinitionWorkflow",
)
20+
21+
22+
class PythonWorkflowDefinitionBaseNode(BaseModel):
    """Fields common to every workflow node variant.

    Concrete subclasses narrow ``type`` to a ``Literal`` value so that the
    node union can dispatch on it as a pydantic discriminated union.
    """

    # Unique integer identifier of the node within the workflow graph.
    id: int
    # Overridden in each subclass with a Literal type (union discriminator).
    type: str
29+
30+
31+
class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
    """A workflow input: a named value fed into the graph."""

    type: Literal["input"]
    # Name under which the input is exposed to the workflow.
    name: str
    # Optional literal value for this input; defaults to None when absent.
    value: Optional[Any] = None
37+
38+
39+
class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):
    """A workflow output: a named result exposed by the graph."""

    type: Literal["output"]
    # Name under which the result is exposed by the workflow.
    name: str
44+
45+
46+
class PythonWorkflowDefinitionFunctionNode(PythonWorkflowDefinitionBaseNode):
    """
    Model for function execution nodes.

    ``value`` identifies the callable to execute as a dotted
    ``module.function`` path.
    """

    type: Literal["function"]
    value: str  # Expected format: 'module.function'

    @field_validator("value")
    @classmethod
    def check_value_format(cls, v: str) -> str:
        """Validate that ``value`` is a plausible 'module.function' path.

        Raises:
            ValueError: If ``value`` is empty, has no period, or starts or
                ends with a period.
        """
        if not v or "." not in v or v.startswith(".") or v.endswith("."):
            # BUG FIX: the original wrote trailing commas after each string
            # piece, making ``msg`` a tuple — ValueError then carried a tuple
            # instead of a readable message. Use implicit concatenation.
            msg = (
                "FunctionNode 'value' must be a non-empty string "
                "in 'module.function' format with at least one period."
            )
            raise ValueError(msg)
        return v
65+
66+
67+
# Discriminated union over the node variants: pydantic inspects the 'type'
# field to pick the concrete model during validation.
PythonWorkflowDefinitionNode = Annotated[
    Union[
        PythonWorkflowDefinitionInputNode,
        PythonWorkflowDefinitionOutputNode,
        PythonWorkflowDefinitionFunctionNode,
    ],
    Field(discriminator="type"),
]
76+
77+
78+
class PythonWorkflowDefinitionEdge(BaseModel):
    """A directed connection between two workflow nodes.

    ``sourcePort`` is normalized on input (None becomes the internal
    sentinel) and de-normalized on output (the sentinel becomes None), so
    the JSON form always uses null for a node's default output port.
    """

    target: int
    targetPort: Optional[str] = None
    source: int
    sourcePort: Optional[str] = None

    @field_validator("sourcePort", mode="before")
    @classmethod
    def handle_default_source(cls, v: Any) -> Optional[str]:
        """Normalize an incoming sourcePort before standard validation.

        An omitted port (None) is replaced by INTERNAL_DEFAULT_HANDLE;
        explicitly supplying the reserved handle name is rejected.
        """
        if v == INTERNAL_DEFAULT_HANDLE:
            # The reserved handle may only arise from normalizing None.
            msg = (
                f"Explicit use of reserved sourcePort '{INTERNAL_DEFAULT_HANDLE}' "
                f"is not allowed. Use null/None for default output."
            )
            raise ValueError(msg)
        return INTERNAL_DEFAULT_HANDLE if v is None else v

    @field_serializer("sourcePort")
    def serialize_source_handle(self, v: Optional[str]) -> Optional[str]:
        """Emit None for the internal default handle; pass others through."""
        return None if v == INTERNAL_DEFAULT_HANDLE else v
114+
115+
116+
class PythonWorkflowDefinitionWorkflow(BaseModel):
    """The main workflow model: a validated graph of nodes and edges."""

    nodes: List[PythonWorkflowDefinitionNode]
    edges: List[PythonWorkflowDefinitionEdge]

    def dump_json(
        self,
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> str:
        """
        Dump the workflow model to a JSON string.

        Args:
            indent: JSON indentation level.
            **kwargs: Additional keyword arguments passed to pydantic's
                model_dump (e.g. custom 'exclude' rules).

        Returns:
            JSON string representation of the workflow.

        Raises:
            TypeError: If the dumped dictionary contains values that are
                not JSON serializable.
        """
        # Dump to a dictionary first with mode='json' so field serializers
        # (e.g. the sourcePort sentinel -> null mapping) are applied.
        workflow_dict = self.model_dump(mode="json", **kwargs)

        try:
            json_string = json.dumps(workflow_dict, indent=indent)
            logger.info("Successfully dumped workflow model to JSON string.")
            return json_string
        except TypeError as e:
            logger.error(
                "Error serializing workflow dictionary to JSON: %s", e, exc_info=True
            )
            raise  # Re-raise after logging

    def dump_json_file(
        self,
        file_name: Union[str, Path],
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> None:
        """
        Dump the workflow model to a JSON file.

        Args:
            file_name: Path to the output JSON file.
            indent: JSON indentation level.
            **kwargs: Additional keyword arguments passed to pydantic's
                model_dump via dump_json.

        Raises:
            IOError: If the file cannot be written.
        """
        logger.info("Dumping workflow model to JSON file: %s", file_name)
        # Pass kwargs to dump_json, which forwards them to model_dump.
        json_string = self.dump_json(
            indent=indent,
            **kwargs,
        )
        try:
            with open(file_name, "w", encoding="utf-8") as f:
                f.write(json_string)
            logger.info("Successfully wrote workflow model to %s.", file_name)
        except IOError as e:
            logger.error(
                "Error writing workflow model to file %s: %s",
                file_name,
                e,
                exc_info=True,
            )
            raise

    @classmethod
    def load_json_str(cls: Type[T], json_data: Union[str, bytes]) -> dict:
        """
        Load and validate workflow data from a JSON string or bytes.

        Args:
            json_data: The JSON data as a string or bytes.

        Returns:
            The validated workflow as a plain dictionary (model_dump output),
            not a pydantic model instance.

        Raises:
            pydantic.ValidationError: If validation fails.
            json.JSONDecodeError: If json_data is not valid JSON.
        """
        logger.info("Loading workflow model from JSON data...")
        try:
            # Pydantic v2 method handles bytes or str directly.
            instance = cls.model_validate_json(json_data)
            logger.info(
                "Successfully loaded and validated workflow model from JSON data."
            )
            # Callers work with plain dictionaries, not pydantic models.
            return instance.model_dump()
        except ValidationError:  # Catch validation errors specifically
            logger.error("Workflow model validation failed.", exc_info=True)
            raise
        except json.JSONDecodeError:  # Catch JSON parsing errors specifically
            logger.error("Invalid JSON format encountered.", exc_info=True)
            raise
        except Exception as e:  # Catch any other unexpected errors
            logger.error(
                "An unexpected error occurred during JSON loading: %s",
                e,
                exc_info=True,
            )
            raise

    @classmethod
    def load_json_file(cls: Type[T], file_name: Union[str, Path]) -> dict:
        """
        Load and validate workflow data from a JSON file.

        Args:
            file_name: The path to the JSON file.

        Returns:
            The validated workflow as a plain dictionary (model_dump output),
            not a pydantic model instance.

        Raises:
            FileNotFoundError: If the file is not found.
            pydantic.ValidationError: If validation fails.
            json.JSONDecodeError: If the file is not valid JSON.
            IOError: If there are other file reading issues.
        """
        logger.info("Loading workflow model from JSON file: %s", file_name)
        try:
            file_content = Path(file_name).read_text(encoding="utf-8")
            # Delegate validation to the string loading method.
            return cls.load_json_str(file_content)
        except FileNotFoundError:
            logger.error("JSON file not found: %s", file_name, exc_info=True)
            raise
        except IOError as e:
            logger.error("Error reading JSON file %s: %s", file_name, e, exc_info=True)
            raise

0 commit comments

Comments
 (0)