Skip to content

Commit

Permalink
Add enable_kwargs tag in tools.json for customer python tool (#1006)
Browse files Browse the repository at this point in the history
# Description
#714 asked support for
kwargs.

Add `enable_kwargs` tag in tools.json for customer python tool.
Then clients could add corresponding support for kwargs by yaml
authoring.
We don't need to modify execution logic to support kwargs.

Confirmed with @lalala123123 `_parse_tool_from_function` in
`_sdk.operations._tool_operations.py` will not be used to generate meta.

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes].**
- [X] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Robben Wang <robbenwang@microsoft.com>
  • Loading branch information
huaiyan and Robben Wang authored Nov 8, 2023
1 parent 6c823a3 commit 264984f
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 97 deletions.
1 change: 1 addition & 0 deletions src/promptflow/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Features Added

- [Executor] Add `enable_kwargs` tag in tools.json for customer python tool.
- [SDK/CLI] Support `pfazure flow create`. Create a flow on Azure AI from local flow folder.
- [SDK/CLI] Changed column mapping `${run.inputs.xx}`'s behavior, it will refer to run's data columns instead of run's inputs columns.

Expand Down
3 changes: 2 additions & 1 deletion src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(
inputs, _, _, enable_kwargs = function_to_interface(
f, initialize_inputs=initialize_inputs, gen_custom_type_conn=gen_custom_type_conn,
skip_prompt_template=skip_prompt_template)
except Exception as e:
Expand All @@ -164,6 +164,7 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa
class_name=class_name,
function=f.__name__,
module=f.__module__,
enable_kwargs=enable_kwargs,
)


Expand Down
5 changes: 3 additions & 2 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_c
if any(k for k in initialize_inputs if k in sign.parameters):
raise Exception(f'Duplicate inputs found from {f.__name__!r} and "__init__()"!')
all_inputs = {**initialize_inputs}
enable_kwargs = any([param.kind == inspect.Parameter.VAR_KEYWORD for _, param in sign.parameters.items()])
all_inputs.update(
{
k: v
Expand All @@ -138,7 +139,7 @@ def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_c
connection_types.append(input_def.type)
outputs = {}
# Note: We don't have output definition now
return input_defs, outputs, connection_types
return input_defs, outputs, connection_types, enable_kwargs


def function_to_tool_definition(f: Callable, type=None, initialize_inputs=None) -> Tool:
Expand All @@ -152,7 +153,7 @@ def function_to_tool_definition(f: Callable, type=None, initialize_inputs=None)
"""
if hasattr(f, "__original_function"):
f = f.__original_function
inputs, outputs, _ = function_to_interface(f, initialize_inputs)
inputs, outputs, _, _ = function_to_interface(f, initialize_inputs)
# Hack to get class name
class_name = None
if "." in f.__qualname__:
Expand Down
4 changes: 4 additions & 0 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class Tool:
:type is_builtin: Optional[bool]
:param stage: The stage of the tool
:type stage: Optional[str]
:param enable_kwargs: Whether to enable kwargs, only available for customer python tool
:type enable_kwargs: Optional[bool]
"""

name: str
Expand All @@ -376,6 +378,7 @@ class Tool:
connection_type: Optional[List[str]] = None
is_builtin: Optional[bool] = None
stage: Optional[str] = None
enable_kwargs: Optional[bool] = False

def serialize(self) -> dict:
"""Serialize tool to dict and skip None fields.
Expand Down Expand Up @@ -415,6 +418,7 @@ def deserialize(data: dict) -> "Tool":
connection_type=data.get("connection_type"),
is_builtin=data.get("is_builtin"),
stage=data.get("stage"),
enable_kwargs=data.get("enable_kwargs", False),
)

def _require_connection(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type
if v.value_type != InputValueType.LITERAL:
continue
tool_input = tool.inputs.get(k)
if tool_input is None:
if tool_input is None: # For kwargs input, tool_input is None.
continue
value_type = tool_input.type[0]
updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_function_to_interface(self):
def func(conn: [AzureOpenAIConnection, CustomConnection], input: [str, int]):
pass

input_defs, _, connection_types = function_to_interface(func)
input_defs, _, connection_types, _ = function_to_interface(func)
assert len(input_defs) == 2
assert input_defs["conn"].type == ["AzureOpenAIConnection", "CustomConnection"]
assert input_defs["input"].type == [ValueType.OBJECT]
Expand All @@ -69,6 +69,19 @@ def func(input_str: str):
function_to_interface(func, {"input_str": "test"})
assert "Duplicate inputs found from" in exec_info.value.args[0]

def test_function_to_interface_with_kwargs(self):
def func(input_str: str, **kwargs):
pass

_, _, _, enable_kwargs = function_to_interface(func)
assert enable_kwargs is True

def func(input_str: str):
pass

_, _, _, enable_kwargs = function_to_interface(func)
assert enable_kwargs is False

def test_param_to_definition(self):
from promptflow._sdk.entities import CustomStrongTypeConnection
from promptflow.contracts.tool import Secret
Expand Down
140 changes: 48 additions & 92 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from pathlib import Path

import pytest

from promptflow._core.tool import tool
from promptflow.entities import DynamicList, InputSetting
from promptflow._sdk._pf_client import PFClient
from promptflow.entities import DynamicList, InputSetting
from promptflow.exceptions import UserErrorException

PROMOTFLOW_ROOT = Path(__file__) / "../../../.."
Expand Down Expand Up @@ -34,28 +35,21 @@ def test_python_tool_meta(self):
"test_tool.python_tool.PythonTool.python_tool": {
"class_name": "PythonTool",
"function": "python_tool",
"inputs": {
"connection": {"type": ["AzureOpenAIConnection"]},
"input1": {"type": ["string"]}
},
"inputs": {"connection": {"type": ["AzureOpenAIConnection"]}, "input1": {"type": ["string"]}},
"module": "test_tool.python_tool",
"name": "PythonTool.python_tool",
"type": "python",
},
"test_tool.python_tool.my_python_tool": {
"function": "my_python_tool",
"inputs": {
"input1": {"type": ["string"]}
},
"inputs": {"input1": {"type": ["string"]}},
"module": "test_tool.python_tool",
"name": "python_tool",
"type": "python",
},
"test_tool.python_tool.my_python_tool_without_name": {
"function": "my_python_tool_without_name",
"inputs": {
"input1": {"type": ["string"]}
},
"inputs": {"input1": {"type": ["string"]}},
"module": "test_tool.python_tool",
"name": "my_python_tool_without_name",
"type": "python",
Expand All @@ -67,33 +61,31 @@ def test_llm_tool_meta(self):
tool_path = TOOL_ROOT / "custom_llm_tool.py"
tool_meta = self.get_tool_meta(tool_path)
expect_tool_meta = {
'test_tool.custom_llm_tool.my_tool': {
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs': {
'connection': {'type': ['CustomConnection']}
},
'description': 'This is a tool to demonstrate the custom_llm tool type',
'module': 'test_tool.custom_llm_tool',
'function': 'my_tool'
"test_tool.custom_llm_tool.my_tool": {
"name": "My Custom LLM Tool",
"type": "custom_llm",
"inputs": {"connection": {"type": ["CustomConnection"]}},
"description": "This is a tool to demonstrate the custom_llm tool type",
"module": "test_tool.custom_llm_tool",
"function": "my_tool",
"enable_kwargs": True,
},
"test_tool.custom_llm_tool.TestCustomLLMTool.tool_func": {
"name": "My Custom LLM Tool",
"type": "custom_llm",
"inputs": {"connection": {"type": ["AzureOpenAIConnection"]}, "api": {"type": ["string"]}},
"description": "This is a tool to demonstrate the custom_llm tool type",
"module": "test_tool.custom_llm_tool",
"class_name": "TestCustomLLMTool",
"function": "tool_func",
"enable_kwargs": True,
},
'test_tool.custom_llm_tool.TestCustomLLMTool.tool_func': {
'name': 'My Custom LLM Tool',
'type': 'custom_llm',
'inputs': {
'connection': {'type': ['AzureOpenAIConnection']},
'api': {'type': ['string']}
},
'description': 'This is a tool to demonstrate the custom_llm tool type',
'module': 'test_tool.custom_llm_tool',
'class_name': 'TestCustomLLMTool',
'function': 'tool_func'
}
}
assert tool_meta == expect_tool_meta

def test_invalid_tool_type(self):
with pytest.raises(UserErrorException) as exception:

@tool(name="invalid_tool_type", type="invalid_type")
def invalid_tool_type():
pass
Expand All @@ -107,10 +99,7 @@ def test_tool_with_custom_connection(self):
"test_tool.tool_with_custom_connection.MyTool.my_tool": {
"name": "My Second Tool",
"type": "python",
"inputs": {
"connection": {"type": ["CustomConnection"]},
"input_text": {"type": ["string"]}
},
"inputs": {"connection": {"type": ["CustomConnection"]}, "input_text": {"type": ["string"]}},
"description": "This is my second tool",
"module": "test_tool.tool_with_custom_connection",
"class_name": "MyTool",
Expand Down Expand Up @@ -158,12 +147,7 @@ def test_tool_with_input_settings(self):
"optional": True,
"default": "",
},
{
"name": "size",
"type": ["int"],
"optional": True,
"default": 10
},
{"name": "size", "type": ["int"], "optional": True, "default": 10},
],
},
},
Expand All @@ -183,37 +167,13 @@ def test_tool_with_input_settings(self):
"name": "My Tool with Enabled By Value",
"type": "python",
"inputs": {
"user_type": {
"type": [
"string"
],
"enum": [
"student",
"teacher"
]
},
"student_id": {
"type": [
"string"
],
"enabled_by": "user_type",
"enabled_by_value": [
"student"
]
},
"teacher_id": {
"type": [
"string"
],
"enabled_by": "user_type",
"enabled_by_value": [
"teacher"
]
}
"user_type": {"type": ["string"], "enum": ["student", "teacher"]},
"student_id": {"type": ["string"], "enabled_by": "user_type", "enabled_by_value": ["student"]},
"teacher_id": {"type": ["string"], "enabled_by": "user_type", "enabled_by_value": ["teacher"]},
},
"description": "This is my tool with enabled by value",
"module": "test_tool.tool_with_enabled_by_value",
"function": "my_tool"
"function": "my_tool",
}
}
assert tool_meta == expect_tool_meta
Expand All @@ -226,16 +186,14 @@ def my_list_func(prefix: str, size: int = 10):
invalid_dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"prefix": "invalid_input"})
input_settings = {
"input_text": InputSetting(
dynamic_list=invalid_dynamic_list_setting,
allow_manual_entry=True,
is_multi_select=True
dynamic_list=invalid_dynamic_list_setting, allow_manual_entry=True, is_multi_select=True
)
}

@tool(
name="My Tool with Dynamic List Input",
description="This is my tool with dynamic list input",
input_settings=input_settings
input_settings=input_settings,
)
def my_tool(input_text: list, input_prefix: str) -> str:
return f"Hello {input_prefix} {','.join(input_text)}"
Expand All @@ -246,19 +204,18 @@ def my_tool(input_text: list, input_prefix: str) -> str:

# invalid dynamic func input
invalid_dynamic_list_setting = DynamicList(
function=my_list_func, input_mapping={"invalid_input": "input_prefix"})
function=my_list_func, input_mapping={"invalid_input": "input_prefix"}
)
input_settings = {
"input_text": InputSetting(
dynamic_list=invalid_dynamic_list_setting,
allow_manual_entry=True,
is_multi_select=True
dynamic_list=invalid_dynamic_list_setting, allow_manual_entry=True, is_multi_select=True
)
}

@tool(
name="My Tool with Dynamic List Input",
description="This is my tool with dynamic list input",
input_settings=input_settings
input_settings=input_settings,
)
def my_tool(input_text: list, input_prefix: str) -> str:
return f"Hello {input_prefix} {','.join(input_text)}"
Expand All @@ -270,13 +227,15 @@ def my_tool(input_text: list, input_prefix: str) -> str:
# check required inputs of dynamic list func
invalid_dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"size": "input_prefix"})
input_settings = {
"input_text": InputSetting(dynamic_list=invalid_dynamic_list_setting, )
"input_text": InputSetting(
dynamic_list=invalid_dynamic_list_setting,
)
}

@tool(
name="My Tool with Dynamic List Input",
description="This is my tool with dynamic list input",
input_settings=input_settings
input_settings=input_settings,
)
def my_tool(input_text: list, input_prefix: str) -> str:
return f"Hello {input_prefix} {','.join(input_text)}"
Expand All @@ -295,22 +254,19 @@ def enabled_by_with_invalid_input(input1: str, input2: str):

with pytest.raises(UserErrorException) as exception:
_client._tools._serialize_tool(enabled_by_with_invalid_input)
assert "Cannot find the input \"invalid_input\"" in exception.value.message
assert 'Cannot find the input "invalid_input"' in exception.value.message

def test_tool_with_file_path_input(self):
tool_path = TOOL_ROOT / "tool_with_file_path_input.py"
tool_meta = self.get_tool_meta(tool_path)
expect_tool_meta = {
'test_tool.tool_with_file_path_input.my_tool': {
'name': 'Tool with FilePath Input',
'type': 'python',
'inputs': {
'input_file': {'type': ['file_path']},
'input_text': {'type': ['string']}
},
'description': 'This is a tool to demonstrate the usage of FilePath input',
'module': 'test_tool.tool_with_file_path_input',
'function': 'my_tool'
"test_tool.tool_with_file_path_input.my_tool": {
"name": "Tool with FilePath Input",
"type": "python",
"inputs": {"input_file": {"type": ["file_path"]}, "input_text": {"type": ["string"]}},
"description": "This is a tool to demonstrate the usage of FilePath input",
"module": "test_tool.tool_with_file_path_input",
"function": "my_tool",
}
}
assert expect_tool_meta == tool_meta

0 comments on commit 264984f

Please sign in to comment.