Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable validation in pf.flow._save #2622

Merged
merged 9 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/json_schema/gen_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_required(self, obj):
FormRecognizerConnectionSchema, CustomConnectionSchema, WeaviateConnectionSchema, ServerlessConnectionSchema, \
CustomStrongTypeConnectionSchema
from promptflow._sdk.schemas._run import RunSchema
from promptflow._sdk.schemas._flow import FlowSchema, EagerFlowSchema
from promptflow._sdk.schemas._flow import FlowSchema, FlexFlowSchema


if __name__ == "__main__":
Expand All @@ -167,7 +167,7 @@ def get_required(self, obj):

# Special case for Flow and FlexFlow
if "Flow" in args.output_file:
cls_list = [FlowSchema, EagerFlowSchema]
cls_list = [FlowSchema, FlexFlowSchema]
schema_list = []
for cls in cls_list:
target_schema = PatchedJSONSchema().dump(cls(context={"base_path": "./"}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def _generate_flow_meta(
"""Generate tool meta from files.

:param flow_directory: flow directory
:param tools: tool list
:param raise_error: whether raise error when generate meta failed
:param timeout: timeout for generate meta
:param include_errors_in_output: whether include errors in output
:param load_in_subprocess: whether load tool meta with subprocess to prevent system path disturb. Default is True.
If set to False, will load tool meta in sync mode and timeout need to be handled outside current process.
:return: tool meta dict
Expand Down
94 changes: 75 additions & 19 deletions src/promptflow-core/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
This file can generate a meta file for the given prompt template or a python file.
"""
import dataclasses
import importlib.util
import inspect
import json
Expand Down Expand Up @@ -39,8 +40,13 @@
last_frame_info,
)
from promptflow._utils.process_utils import block_terminate_signal_to_parent
from promptflow._utils.tool_utils import asdict_without_none, function_to_interface, get_inputs_for_prompt_template
from promptflow.contracts.tool import Tool, ToolType
from promptflow._utils.tool_utils import (
asdict_without_none,
function_to_interface,
get_inputs_for_prompt_template,
resolve_annotation,
)
from promptflow.contracts.tool import InputDefinition, Tool, ToolType, ValueType
from promptflow.exceptions import ErrorTarget, UserErrorException


Expand Down Expand Up @@ -149,7 +155,11 @@ def collect_tool_methods_with_init_inputs_in_module(m):


def _parse_tool_from_function(
f, initialize_inputs=None, gen_custom_type_conn=False, skip_prompt_template=False, include_outputs=False
f,
initialize_inputs=None,
gen_custom_type_conn=False,
skip_prompt_template=False,
include_outputs=False,
):
try:
tool_type = getattr(f, "__type", None) or ToolType.PYTHON
Expand Down Expand Up @@ -495,33 +505,52 @@ def generate_tool_meta_in_subprocess(
return tool_dict, exception_dict


def generate_flow_meta_dict_by_file(data: dict, source: str = None, path: str = None):
"""Generate flow meta for eager flow data.
Original flow configuration like environment variables will be generated in meta.
"""
def validate_interface(ports, fields, tool_name, port_type):
    """Check that every object-typed port is annotated with a serializable type.

    Ports whose ``type`` does not contain ``ValueType.OBJECT`` are skipped. For the
    remaining ports, the annotation stored in ``fields`` under the same name is
    resolved with ``resolve_annotation`` and must be either ``dict`` or absent
    (``inspect.Signature.empty``).

    :param ports: mapping from port name to a port definition exposing ``type``.
    :param fields: mapping from port name to the raw annotation of that port.
    :param tool_name: name of the tool, used only in the error message.
    :param port_type: label of the port kind, used only in the error message.
    :raises BadFunctionInterface: if an object-typed port has any other annotation.
    """
    for port_name, definition in ports.items():
        if ValueType.OBJECT in definition.type:
            annotation = resolve_annotation(fields[port_name])
            if annotation not in [dict, inspect.Signature.empty]:
                raise BadFunctionInterface(
                    message_format=(
                        "Parse interface for tool '{tool_name}' failed: "
                        "The {port_type} '{k}' is of a complex python type. "
                        "Please use a dict instead."
                    ),
                    port_type=port_type,
                    tool_name=tool_name,
                    k=port_name,
                )

entry = data.get("entry")
if path:
m = load_python_module_from_file(Path(path))
else:
m = load_python_module_from_entry(entry)

f, cls = collect_flow_entry_in_module(m, entry)
def generate_flow_meta_dict_by_object(f, cls):
# Since the flow meta is generated from the entry function, we leverage the function
# _parse_tool_from_function to parse the interface of the entry function to get the inputs and outputs.
tool = _parse_tool_from_function(f, include_outputs=True)

# flow will be deployed as a service finally, so we will raise error if any port is of a complex python type.
# need to update this if we support serializable python object in the future.
signature = inspect.signature(f)
input_fields = {k: v.annotation for k, v in signature.parameters.items()}
validate_interface(tool.inputs, input_fields, tool.name, "input")

# validate output
return_type = resolve_annotation(signature.return_annotation)
if dataclasses.is_dataclass(return_type):
output_fields = {x.name: x.type for x in dataclasses.fields(return_type)}
else:
output_fields = {"output": return_type}
validate_interface(tool.outputs, output_fields, tool.name, "output")

# Include data in generated meta to avoid flow definition's fields(e.g. environment variable) missing.
flow_meta = {"function": f.__name__, **data}
flow_meta = {}
if cls:
init_tool = _parse_tool_from_function(cls.__init__, include_outputs=False)
init_inputs = init_tool.inputs
# no need to validate init as dict is not allowed in init for now.
else:
init_inputs = None

if source:
flow_meta["source"] = source

for ports, meta_key in [
(tool.inputs, "inputs"),
(tool.outputs, "outputs"),
Expand All @@ -533,13 +562,40 @@ def generate_flow_meta_dict_by_file(data: dict, source: str = None, path: str =
flow_meta[meta_key] = {}
for k, v in ports.items():
# We didn't support specifying multiple types for inputs/outputs/init, so we only take the first one.
flow_meta[meta_key][k] = {"type": v.type[0].value}
param_type = v.type[0]
elliotzh marked this conversation as resolved.
Show resolved Hide resolved
# For connections, param_type is a string; for primitive types, param_type is
# a promptflow.contracts.tool.ValueType, so its string value is used instead.
param_type = param_type.value if isinstance(param_type, ValueType) else param_type

flow_meta[meta_key][k] = {"type": param_type}
# init/inputs may have default value
if meta_key != "outputs" and v.default is not None:
if isinstance(v, InputDefinition) and v.default is not None:
flow_meta[meta_key][k]["default"] = v.default
return flow_meta


def generate_flow_meta_dict_by_file(data: dict, source: str = None, path: str = None):
    """Build the flow meta dict for a flow whose entry is defined in Python code.

    The fields of ``data`` are copied into the result, so original flow
    configuration such as environment variables is kept in the meta. The
    interface generated from the entry object is merged in last.

    :param data: flow definition dict; its ``entry`` value identifies the entry to load.
    :param source: optional value stored under the ``source`` key of the meta.
    :param path: optional path of a python file to load the module from; when
        omitted, the module is loaded from the entry instead.
    :return: flow meta dict
    """
    entry = data.get("entry")
    module = load_python_module_from_file(Path(path)) if path else load_python_module_from_entry(entry)
    f, cls = collect_flow_entry_in_module(module, entry)

    # Start from the entry function name plus the raw definition; keys already in data win over "function".
    flow_meta = {"function": f.__name__, **data}
    if source:
        flow_meta["source"] = source

    # Keys generated from the entry object override same-named keys copied from data.
    generated_meta = generate_flow_meta_dict_by_object(f, cls)
    flow_meta.update(generated_meta)
    return flow_meta


class ToolValidationError(UserErrorException):
"""Base exception raised when failed to validate tool."""

Expand Down
8 changes: 4 additions & 4 deletions src/promptflow-core/promptflow/contracts/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,8 @@ def _replace_with_variant(self, variant_node: Node, variant_tools: list):


@dataclass
class EagerFlow(FlowBase):
"""This class represents an eager flow.
class FlexFlow(FlowBase):
"""This class represents a flex flow.

:param id: The id of the flow.
:type id: str
Expand All @@ -947,7 +947,7 @@ class EagerFlow(FlowBase):
message_format: str = MessageFormatType.BASIC

@staticmethod
def deserialize(data: dict) -> "EagerFlow":
def deserialize(data: dict) -> "FlexFlow":
"""Deserialize the flow from a dict.

:param data: The dict to be deserialized.
Expand All @@ -958,7 +958,7 @@ def deserialize(data: dict) -> "EagerFlow":

inputs = data.get("inputs") or {}
outputs = data.get("outputs") or {}
return EagerFlow(
return FlexFlow(
id=data.get("id", "default_flow_id"),
name=data.get("name", "default_flow"),
inputs={name: FlowInputDefinition.deserialize(i) for name, i in inputs.items()},
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow-core/promptflow/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def init_executable(*, flow_dag: dict = None, flow_path: Path = None, working_di
if not working_dir:
working_dir = flow_dir

from promptflow.contracts.flow import EagerFlow as ExecutableEagerFlow
from promptflow.contracts.flow import FlexFlow as ExecutableEagerFlow
from promptflow.contracts.flow import Flow as ExecutableFlow

if is_flex_flow(yaml_dict=flow_dag):
Expand Down
10 changes: 5 additions & 5 deletions src/promptflow-devkit/promptflow/_sdk/entities/_flows/flex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from promptflow._utils.flow_utils import resolve_entry_file
from promptflow.exceptions import ErrorTarget, UserErrorException

from .dag import Flow
from .base import Flow as FlowBase


class FlexFlow(Flow):
class FlexFlow(FlowBase):
"""A FlexFlow represents a flow defined with codes directly. It doesn't involve a directed acyclic graph (DAG)
explicitly, but its entry function hasn't been provided.
FlexFlow basically behaves like a Flow.
Expand Down Expand Up @@ -71,9 +71,9 @@ def _load(cls, path: Path, data: dict, raise_error=True, **kwargs):
@classmethod
def _create_schema_for_validation(cls, context):
# import here to avoid circular import
from promptflow._sdk.schemas._flow import EagerFlowSchema
from promptflow._sdk.schemas._flow import FlexFlowSchema

return EagerFlowSchema(context=context)
return FlexFlowSchema(context=context)

def _default_context(self) -> dict:
return {BASE_PATH_CONTEXT_KEY: self.code}
Expand Down Expand Up @@ -111,7 +111,7 @@ def _resolve_entry_file(cls, entry: str, working_dir: Path) -> Optional[str]:

def _init_executable(self, **kwargs):
from promptflow._proxy import ProxyFactory
from promptflow.contracts.flow import EagerFlow as ExecutableEagerFlow
from promptflow.contracts.flow import FlexFlow as ExecutableEagerFlow

meta_dict = (
ProxyFactory()
Expand Down
Loading
Loading