From a45906494ff178377ef23afe2f5468072f674b2c Mon Sep 17 00:00:00 2001 From: zhangxingzhi Date: Sun, 7 Apr 2024 17:09:52 +0800 Subject: [PATCH] feat: do validate in signature inference --- .../_sdk/operations/_flow_operations.py | 14 +++++---- .../sdk_cli_test/e2etests/test_flow_save.py | 29 ++++++++++++++++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py b/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py index 13dd6e670d51..a313a8a08497 100644 --- a/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py +++ b/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py @@ -1027,7 +1027,7 @@ def _merge_signature(extracted, signature_overrides): return signature @monitor_operation(activity_name="pf.flows._infer_signature", activity_type=ActivityType.INTERNALCALL) - def _infer_signature(self, entry: Callable, keep_entry: bool = False) -> Tuple[dict, Path]: + def _infer_signature(self, entry: Callable, keep_entry: bool = False, validate: bool = True) -> Tuple[dict, Path]: """Infer signature of a flow entry. Note that this is a Python only feature. @@ -1047,9 +1047,13 @@ def _infer_signature(self, entry: Callable, keep_entry: bool = False) -> Tuple[d flow_meta = generate_flow_meta_dict_by_object(func, cls) source_path = Path(inspect.getfile(entry)) - if keep_entry: - # TODO: should we handle the case that entry is not defined in root level of the source? - flow_meta["entry"] = f"{source_path.stem}:{entry.__name__}" + # TODO: should we handle the case that entry is not defined in root level of the source? + flow_meta["entry"] = f"{source_path.stem}:{entry.__name__}" + if validate: + flow = FlexFlow(path=source_path, code=source_path.parent, data=flow_meta, entry=flow_meta["entry"]) + flow._validate(raise_error=True) + if not keep_entry: + del flow_meta["entry"] return flow_meta, source_path.parent @monitor_operation(activity_name="pf.flows._save", activity_type=ActivityType.INTERNALCALL) @@ -1113,7 +1117,7 @@ def _save( "Code path will be the parent of entry source " "and can't be customized when entry is a callable." ) else: - entry_meta, code = self._infer_signature(entry, keep_entry=True) + entry_meta, code = self._infer_signature(entry, keep_entry=True, validate=False) data = self._merge_signature(entry_meta, signature) data["entry"] = entry_meta["entry"] diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_save.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_save.py index 23f222ad2fd2..46144ef6d5dd 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_save.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_save.py @@ -38,6 +38,14 @@ def __call__(self, text: str) -> str: return f"Hello {text} via {self.connection.name}!" +class GlobalHelloWithInvalidInit: + def __init__(self, connection: AzureOpenAIConnection, words: list): + self.connection = connection + + def __call__(self, text: str) -> str: + return f"Hello {text} via {self.connection.name}!" + + def global_hello(text: str) -> str: return f"Hello {text}!" @@ -377,7 +385,7 @@ def test_pf_save_callable_class(self): }, } - def test_pf_save_callable_object(self): + def test_pf_save_callable_function(self): pf = PFClient() target_path = f"{FLOWS_DIR}/saved/hello_callable" if os.path.exists(target_path): @@ -401,3 +409,22 @@ def test_pf_save_callable_object(self): }, }, } + + def test_infer_signature(self): + pf = PFClient() + flow_meta, code = pf.flows._infer_signature(entry=global_hello) + assert flow_meta == { + "inputs": { + "text": { + "type": "string", + } + }, + "outputs": { + "output": { + "type": "string", + }, + }, + } + + with pytest.raises(UserErrorException, match="Schema validation failed: {'init.words.type'"): + pf.flows._infer_signature(entry=GlobalHelloWithInvalidInit)