|
| 1 | +from pathlib import Path |
| 2 | +from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar |
| 3 | +from pydantic import BaseModel, Field, field_validator, field_serializer |
| 4 | +from pydantic import ValidationError |
| 5 | +import json |
| 6 | +import logging |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
# Reserved handle name used internally to represent a node's default
# (unnamed) output port; edges with sourcePort=None are normalized to this.
INTERNAL_DEFAULT_HANDLE = "__result__"
# TypeVar bound to the workflow model so classmethod loaders can be typed
# against the defining class.
T = TypeVar("T", bound="PythonWorkflowDefinitionWorkflow")

# Explicit public API of this module.
__all__ = (
    "PythonWorkflowDefinitionInputNode",
    "PythonWorkflowDefinitionOutputNode",
    "PythonWorkflowDefinitionFunctionNode",
    "PythonWorkflowDefinitionEdge",
    "PythonWorkflowDefinitionWorkflow",
)
| 20 | + |
| 21 | + |
class PythonWorkflowDefinitionBaseNode(BaseModel):
    """Base model for all node types, containing common fields.

    Subclasses narrow ``type`` to a ``Literal`` value so the node union
    below can dispatch on it as a discriminator.
    """

    # Unique integer identifier of the node within the workflow graph.
    id: int
    # The 'type' field will be overridden in subclasses with Literal types
    # to enable discriminated unions.
    type: str
| 29 | + |
| 30 | + |
class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
    """Model for input nodes.

    Carries a named input value into the workflow; ``value`` may be any
    JSON-serializable payload or omitted (None).
    """

    type: Literal["input"]
    # Name of the workflow input this node provides.
    name: str
    # Optional concrete value for the input; defaults to None when unset.
    value: Optional[Any] = None
| 37 | + |
| 38 | + |
class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):
    """Model for output nodes.

    Marks a named result of the workflow; the producing node is connected
    to this node via an edge.
    """

    type: Literal["output"]
    # Name under which the workflow exposes this output.
    name: str
| 44 | + |
| 45 | + |
class PythonWorkflowDefinitionFunctionNode(PythonWorkflowDefinitionBaseNode):
    """
    Model for function execution nodes.

    ``value`` identifies the callable to execute as a dotted path in
    'module.function' format.
    """

    type: Literal["function"]
    value: str  # Expected format: 'module.function'

    @field_validator("value")
    @classmethod
    def check_value_format(cls, v: str) -> str:
        """Reject values that are empty or not in 'module.function' form.

        Raises:
            ValueError: If ``v`` is empty, contains no period, or starts or
                ends with a period.
        """
        if not v or "." not in v or v.startswith(".") or v.endswith("."):
            # BUG FIX: 'msg' was previously a tuple of two strings (the
            # parenthesized strings were separated by commas), so ValueError
            # received a tuple and rendered a garbled message. Build one
            # string via implicit concatenation instead.
            msg = (
                "FunctionNode 'value' must be a non-empty string "
                "in 'module.function' format with at least one period."
            )
            raise ValueError(msg)
        return v
| 65 | + |
| 66 | + |
# Discriminated Union for Nodes: Pydantic selects the concrete node model
# (input/output/function) based on the Literal 'type' field declared on
# each subclass.
PythonWorkflowDefinitionNode = Annotated[
    Union[
        PythonWorkflowDefinitionInputNode,
        PythonWorkflowDefinitionOutputNode,
        PythonWorkflowDefinitionFunctionNode,
    ],
    Field(discriminator="type"),
]
| 76 | + |
| 77 | + |
class PythonWorkflowDefinitionEdge(BaseModel):
    """Model for edges connecting nodes.

    Links a ``source`` node's output port to a ``target`` node's input
    port. An omitted (null) ``sourcePort`` denotes the node's default
    output and is normalized internally to ``INTERNAL_DEFAULT_HANDLE``;
    on serialization that internal handle is mapped back to None.
    """

    target: int
    targetPort: Optional[str] = None
    source: int
    sourcePort: Optional[str] = None

    @field_validator("sourcePort", mode="before")
    @classmethod
    def handle_default_source(cls, v: Any) -> Optional[str]:
        """Normalize an omitted sourcePort before standard validation.

        None is resolved to the internal default handle; supplying the
        reserved handle name explicitly is rejected.
        """
        # Allow not specifying the sourcePort -> null gets resolved to __result__
        if v is None:
            return INTERNAL_DEFAULT_HANDLE
        # Disallow explicit use of the internal reserved handle name
        if v == INTERNAL_DEFAULT_HANDLE:
            msg = (
                f"Explicit use of reserved sourcePort '{INTERNAL_DEFAULT_HANDLE}' "
                f"is not allowed. Use null/None for default output."
            )
            raise ValueError(msg)
        return v

    @field_serializer("sourcePort")
    def serialize_source_handle(self, v: Optional[str]) -> Optional[str]:
        """
        SERIALIZATION (Output): Converts internal INTERNAL_DEFAULT_HANDLE
        ("__result__") back to None; other handle names pass through as-is.
        """
        return None if v == INTERNAL_DEFAULT_HANDLE else v
| 114 | + |
| 115 | + |
class PythonWorkflowDefinitionWorkflow(BaseModel):
    """The main workflow model: a graph of nodes connected by edges."""

    nodes: List[PythonWorkflowDefinitionNode]
    edges: List[PythonWorkflowDefinitionEdge]

    def dump_json(
        self,
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> str:
        """
        Dumps the workflow model to a JSON string.

        Args:
            indent: JSON indentation level (None for compact output).
            **kwargs: Additional keyword arguments passed to Pydantic's
                model_dump (e.g. custom 'exclude' rules).

        Returns:
            JSON string representation of the workflow.

        Raises:
            TypeError: If the dumped dictionary contains values that are
                not JSON-serializable.
        """
        # NOTE(review): the previous docstring documented a nonexistent
        # 'exclude_computed_function_names' parameter; removed.

        # Dump the model to a dictionary first, using mode='json' so field
        # serializers run and values are converted to compatible types.
        workflow_dict = self.model_dump(mode="json", **kwargs)

        # Dump the dictionary to a JSON string
        try:
            json_string = json.dumps(workflow_dict, indent=indent)
            logger.info("Successfully dumped workflow model to JSON string.")
            return json_string
        except TypeError as e:
            logger.error(
                f"Error serializing workflow dictionary to JSON: {e}", exc_info=True
            )
            raise  # Re-raise after logging

    def dump_json_file(
        self,
        file_name: Union[str, Path],
        *,
        indent: Optional[int] = 2,
        **kwargs,
    ) -> None:
        """
        Dumps the workflow model to a JSON file.

        Args:
            file_name: Path to the output JSON file.
            indent: JSON indentation level (None for compact output).
            **kwargs: Additional keyword arguments passed to Pydantic's
                model_dump.

        Raises:
            TypeError: If the workflow contains values that are not
                JSON-serializable.
            IOError: If the file cannot be written.
        """
        # DOC FIX: the argument was documented as 'file_path' although the
        # signature says 'file_name', and a nonexistent
        # 'exclude_computed_function_names' parameter was documented.
        logger.info(f"Dumping workflow model to JSON file: {file_name}")
        # Pass kwargs to dump_json, which passes them to model_dump
        json_string = self.dump_json(
            indent=indent,
            **kwargs,
        )
        try:
            with open(file_name, "w", encoding="utf-8") as f:
                f.write(json_string)
            logger.info(f"Successfully wrote workflow model to {file_name}.")
        except IOError as e:
            logger.error(
                f"Error writing workflow model to file {file_name}: {e}", exc_info=True
            )
            raise

    @classmethod
    def load_json_str(cls: Type[T], json_data: Union[str, bytes]) -> dict:
        """
        Loads and validates workflow data from a JSON string or bytes.

        Args:
            json_data: The JSON data as a string or bytes.

        Returns:
            A plain dict representation of the validated workflow
            (the validated model is immediately dumped via model_dump).

        Raises:
            pydantic.ValidationError: If validation fails (in Pydantic v2
                this also covers malformed JSON input).
            json.JSONDecodeError: Kept for callers that catch it, although
                model_validate_json reports malformed JSON as a
                ValidationError.
        """
        # DOC FIX: previously documented as returning "An instance of
        # PwdWorkflow" although the method returns a dict.
        logger.info("Loading workflow model from JSON data...")
        try:
            # Pydantic v2 method handles bytes or str directly
            instance = cls.model_validate_json(json_data)
            logger.info(
                "Successfully loaded and validated workflow model from JSON data."
            )
            # Dumping applies field serializers (e.g. sourcePort mapping).
            return instance.model_dump()
        except ValidationError:  # Catch validation errors specifically
            logger.error("Workflow model validation failed.", exc_info=True)
            raise
        except json.JSONDecodeError:  # Catch JSON parsing errors specifically
            logger.error("Invalid JSON format encountered.", exc_info=True)
            raise
        except Exception as e:  # Catch any other unexpected errors
            logger.error(
                f"An unexpected error occurred during JSON loading: {e}", exc_info=True
            )
            raise

    @classmethod
    def load_json_file(cls: Type[T], file_name: Union[str, Path]) -> dict:
        """
        Loads and validates workflow data from a JSON file.

        Args:
            file_name: The path to the JSON file.

        Returns:
            A plain dict representation of the validated workflow.

        Raises:
            FileNotFoundError: If the file is not found.
            pydantic.ValidationError: If validation fails.
            json.JSONDecodeError: If the file is not valid JSON.
            IOError: If there are other file reading issues.
        """
        # DOC FIX: the argument was documented as 'file_path' although the
        # signature says 'file_name'; return description corrected to dict.
        logger.info(f"Loading workflow model from JSON file: {file_name}")
        try:
            file_content = Path(file_name).read_text(encoding="utf-8")
            # Delegate validation to the string loading method
            return cls.load_json_str(file_content)
        except FileNotFoundError:
            logger.error(f"JSON file not found: {file_name}", exc_info=True)
            raise
        except IOError as e:
            logger.error(f"Error reading JSON file {file_name}: {e}", exc_info=True)
            raise
0 commit comments