Skip to content

Commit 9a3e4a8

Browse files
committed
Add pydantic models
1 parent d3ae576 commit 9a3e4a8

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

python_workflow_definition/src/python_workflow_definition/_pydantic/__init__.py

Whitespace-only changes.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import List, Union, Optional, Literal, Any, Annotated
2+
from pydantic import BaseModel, Field, field_validator, computed_field
3+
import logging
4+
5+
logger = logging.getLogger(__name__)
6+
7+
INTERNAL_DEFAULT_HANDLE = "__result__"
8+
9+
10+
class BaseNode(BaseModel):
11+
"""Base model for all node types, containing common fields."""
12+
13+
id: int
14+
# The 'type' field will be overridden in subclasses with Literal types
15+
# to enable discriminated unions.
16+
type: str
17+
18+
19+
class InputNode(BaseNode):
20+
"""Model for input nodes."""
21+
22+
type: Literal["input"]
23+
name: str
24+
value: Optional[Any] = None
25+
26+
27+
class OutputNode(BaseNode):
28+
"""Model for output nodes."""
29+
30+
type: Literal["output"]
31+
name: str
32+
33+
34+
class FunctionNode(BaseNode):
35+
"""
36+
Model for function execution nodes.
37+
The 'name' attribute is computed automatically from 'value'.
38+
"""
39+
40+
type: Literal["function"]
41+
value: str # Expected format: 'module.function'
42+
43+
@field_validator("value")
44+
@classmethod
45+
def check_value_format(cls, v: str):
46+
if not v or "." not in v or v.startswith(".") or v.endswith("."):
47+
msg = (
48+
"FunctionNode 'value' must be a non-empty string ",
49+
"in 'module.function' format with at least one period.",
50+
)
51+
raise ValueError(msg)
52+
return v
53+
54+
@computed_field
55+
@property
56+
def name(self) -> str:
57+
"""Dynamically computes the function name from the 'value' field."""
58+
try:
59+
return self.value.rsplit(".", 1)[-1]
60+
except IndexError:
61+
msg = (
62+
f"Could not extract function name from invalid FunctionNode value: '{self.value}' (ID: {self.id}). Check validation."
63+
)
64+
raise ValueError(msg)
65+
except Exception as e:
66+
logger.error(
67+
f"Unexpected error computing name for FunctionNode value: '{self.value}' (ID: {self.id}): {e}",
68+
exc_info=True,
69+
)
70+
raise
71+
72+
73+
# Discriminated Union for Nodes
74+
Node = Annotated[
75+
Union[InputNode, OutputNode, FunctionNode], Field(discriminator="type")
76+
]
77+
78+
79+
class Edge(BaseModel):
80+
"""Model for edges connecting nodes."""
81+
82+
target: Optional[int] = None
83+
targetHandle: Optional[str] = None
84+
source: Optional[int] = None
85+
sourceHandle: Optional[str] = None
86+
87+
@field_validator("sourceHandle", mode="before")
88+
@classmethod
89+
def handle_default_source(cls, v: Any) -> Optional[str]:
90+
"""
91+
Transforms incoming None/null for sourceHandle to INTERNAL_DEFAULT_HANDLE.
92+
Runs before standard validation.
93+
"""
94+
# Allow not specifying the sourceHandle -> null gets resolved to __result__
95+
if v is None:
96+
return INTERNAL_DEFAULT_HANDLE
97+
elif v == INTERNAL_DEFAULT_HANDLE:
98+
# Disallow explicit use of the internal reserved handle name
99+
msg = (
100+
f"Explicit use of reserved sourceHandle '{INTERNAL_DEFAULT_HANDLE}' "
101+
f"is not allowed. Use null/None for default output."
102+
)
103+
raise ValueError(msg)
104+
return v
105+
106+
107+
class Workflow(BaseModel):
108+
"""The main workflow model."""
109+
110+
nodes: List[Node]
111+
edges: List[Edge]

0 commit comments

Comments
 (0)