Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Executor] Gen flow meta for eager flow #2027

Merged
merged 12 commits into from
Feb 22, 2024
42 changes: 40 additions & 2 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def collect_tool_functions_in_module(m):
return tools


def collect_flow_entry_in_module(m, entry):
entry = entry.split(":")[-1]
func = getattr(m, entry, None)
if isinstance(func, types.FunctionType):
return func
raise PythonParsingError(
message_format="Failed to collect flow entry '{entry}' in module '{module}'.",
entry=entry,
module=m.__name__,
)


def collect_tool_methods_in_module(m):
tools = []
for _, obj in inspect.getmembers(m):
Expand All @@ -120,7 +132,9 @@ def collect_tool_methods_with_init_inputs_in_module(m):
return tools


def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=False, skip_prompt_template=False):
def _parse_tool_from_function(
f, initialize_inputs=None, gen_custom_type_conn=False, skip_prompt_template=False, include_outputs=False
):
try:
tool_type = getattr(f, "__type", None) or ToolType.PYTHON
except Exception as e:
Expand All @@ -132,7 +146,7 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _, enable_kwargs = function_to_interface(
inputs, outputs, _, enable_kwargs = function_to_interface(
f,
initialize_inputs=initialize_inputs,
gen_custom_type_conn=gen_custom_type_conn,
Expand All @@ -153,6 +167,7 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa
name=tool_name or f.__qualname__,
description=description or inspect.getdoc(f),
inputs=inputs,
outputs=outputs if include_outputs else None,
type=tool_type,
class_name=class_name,
function=f.__name__,
Expand Down Expand Up @@ -310,6 +325,29 @@ def generate_tool_meta_dict_by_file(path: str, tool_type: ToolType):
)


def generate_flow_meta_dict_by_file(path: str, entry: str, source: str = None):
m = load_python_module_from_file(Path(path))
f = collect_flow_entry_in_module(m, entry)
# Since the flow meta is generated from the entry function, we leverage the function
# _parse_tool_from_function to parse the interface of the entry function to get the inputs and outputs.
tool = _parse_tool_from_function(f, include_outputs=True)

flow_meta = {"entry": entry, "function": f.__name__}
if source:
flow_meta["source"] = source
if tool.inputs:
D-W- marked this conversation as resolved.
Show resolved Hide resolved
flow_meta["inputs"] = {}
for k, v in tool.inputs.items():
# We didn't support specifying multiple types for inputs, so we only take the first one.
flow_meta["inputs"][k] = {"type": v.type[0].value}
if tool.outputs:
D-W- marked this conversation as resolved.
Show resolved Hide resolved
flow_meta["outputs"] = {}
for k, v in tool.outputs.items():
# We didn't support specifying multiple types for outputs, so we only take the first one.
flow_meta["outputs"][k] = {"type": v.type[0].value}
return flow_meta


class ToolValidationError(UserErrorException):
"""Base exception raised when failed to validate tool."""

Expand Down
21 changes: 18 additions & 3 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
from promptflow._utils.utils import is_json_serializable
from promptflow.exceptions import ErrorTarget, UserErrorException

from ..contracts.tool import ConnectionType, InputDefinition, Tool, ToolFuncCallScenario, ToolType, ValueType
from ..contracts.tool import (
ConnectionType,
InputDefinition,
OutputDefinition,
Tool,
ToolFuncCallScenario,
ToolType,
ValueType
)
from ..contracts.types import PromptTemplate

module_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,8 +150,15 @@ def function_to_interface(
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
outputs = {}
# Note: We don't have output definition now
# Resolve output to definition
typ = resolve_annotation(sign.return_annotation)
if typ is inspect.Signature.empty:
output_type = [ValueType.OBJECT]
else:
# If the output annotation is a union type, then it should be a list.
output_type = [ValueType.from_type(t) for t in typ] if isinstance(typ, list) else [ValueType.from_type(typ)]
lumoslnt marked this conversation as resolved.
Show resolved Hide resolved
outputs = {"output": OutputDefinition(type=output_type)}

return input_defs, outputs, connection_types, enable_kwargs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
NoToolDefined,
PythonLoadError,
PythonParsingError,
generate_flow_meta_dict_by_file,
generate_prompt_meta,
generate_python_meta,
generate_tool_meta_dict_by_file,
)
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.exception_utils import ExceptionPresenter

from ...utils import FLOW_ROOT, load_json
from ...utils import EAGER_FLOW_ROOT, FLOW_ROOT, load_json

TEST_ROOT = Path(__file__).parent.parent.parent.parent
TOOLS_ROOT = TEST_ROOT / "test_configs/wrong_tools"
Expand All @@ -34,6 +36,14 @@ def cd_and_run(working_dir, source_path, tool_type):
return f"({e.__class__.__name__}) {e}"


def cd_and_run_generate_flow_meta(working_dir, source_path, entry, source=None):
with _change_working_dir(working_dir):
try:
return generate_flow_meta_dict_by_file(source_path, entry, source)
except Exception as e:
return f"({e.__class__.__name__}) {e}"


def cd_and_run_with_read_text_error(working_dir, source_path, tool_type):
def mock_read_text_error(self: Path, *args, **kwargs):
raise Exception("Mock read text error.")
Expand Down Expand Up @@ -87,6 +97,20 @@ def test_generate_tool_meta_dict_by_file(self, flow_dir, tool_path, tool_type):
expected_dict["type"] = "llm" # We use prompt as default for jinja2
assert meta_dict == expected_dict

@pytest.mark.parametrize(
"flow_dir, entry_path, entry",
[
("dummy_flow_with_trace", "flow_with_trace.py", "flow_with_trace:my_flow"),
]
)
def test_generate_flow_meta(self, flow_dir, entry_path, entry):
wd = str((EAGER_FLOW_ROOT / flow_dir).resolve())
meta_dict = cd_and_run_generate_flow_meta(wd, entry_path, entry, source=entry_path)
assert isinstance(meta_dict, dict), "Call cd_and_run_generate_flow_meta failed:\n" + meta_dict
target_file = (Path(wd) / entry_path).with_suffix(".meta.json")
expected_dict = load_json(target_file)
assert meta_dict == expected_dict

@pytest.mark.parametrize(
"flow_dir, tool_path, tool_type, func, msg_pattern",
[
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"function": "my_flow",
"entry": "flow_with_trace:my_flow",
"inputs": {
"text": {
"type": "string"
},
"models": {
"type": "list"
}
},
"outputs": {
"output": {
"type": "string"
}
},
"source": "flow_with_trace.py"
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def dummy_llm(prompt: str, model: str, wait_seconds: int):
return prompt


async def my_flow(text: str, models: list = []):
async def my_flow(text: str, models: list = []) -> str:
tasks = []
for i, model in enumerate(models):
tasks.append(asyncio.create_task(dummy_llm(text, model, i + 1)))
Expand Down
Loading