Skip to content

Commit

Permalink
feat: do validate in signature inference
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Apr 7, 2024
1 parent a398947 commit a459064
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def _merge_signature(extracted, signature_overrides):
return signature

@monitor_operation(activity_name="pf.flows._infer_signature", activity_type=ActivityType.INTERNALCALL)
def _infer_signature(self, entry: Callable, keep_entry: bool = False) -> Tuple[dict, Path]:
def _infer_signature(self, entry: Callable, keep_entry: bool = False, validate: bool = True) -> Tuple[dict, Path]:
"""Infer signature of a flow entry.
Note that this is a Python only feature.
Expand All @@ -1047,9 +1047,13 @@ def _infer_signature(self, entry: Callable, keep_entry: bool = False) -> Tuple[d

flow_meta = generate_flow_meta_dict_by_object(func, cls)
source_path = Path(inspect.getfile(entry))
if keep_entry:
# TODO: should we handle the case that entry is not defined in root level of the source?
flow_meta["entry"] = f"{source_path.stem}:{entry.__name__}"
# TODO: should we handle the case that entry is not defined in root level of the source?
flow_meta["entry"] = f"{source_path.stem}:{entry.__name__}"
if validate:
flow = FlexFlow(path=source_path, code=source_path.parent, data=flow_meta, entry=flow_meta["entry"])
flow._validate(raise_error=True)
if not keep_entry:
del flow_meta["entry"]
return flow_meta, source_path.parent

@monitor_operation(activity_name="pf.flows._save", activity_type=ActivityType.INTERNALCALL)
Expand Down Expand Up @@ -1113,7 +1117,7 @@ def _save(
"Code path will be the parent of entry source " "and can't be customized when entry is a callable."
)
else:
entry_meta, code = self._infer_signature(entry, keep_entry=True)
entry_meta, code = self._infer_signature(entry, keep_entry=True, validate=False)

data = self._merge_signature(entry_meta, signature)
data["entry"] = entry_meta["entry"]
Expand Down
29 changes: 28 additions & 1 deletion src/promptflow/tests/sdk_cli_test/e2etests/test_flow_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __call__(self, text: str) -> str:
return f"Hello {text} via {self.connection.name}!"


class GlobalHelloWithInvalidInit:
def __init__(self, connection: AzureOpenAIConnection, words: list):
self.connection = connection

def __call__(self, text: str) -> str:
return f"Hello {text} via {self.connection.name}!"


def global_hello(text: str) -> str:
return f"Hello {text}!"

Expand Down Expand Up @@ -377,7 +385,7 @@ def test_pf_save_callable_class(self):
},
}

def test_pf_save_callable_object(self):
def test_pf_save_callable_function(self):
pf = PFClient()
target_path = f"{FLOWS_DIR}/saved/hello_callable"
if os.path.exists(target_path):
Expand All @@ -401,3 +409,22 @@ def test_pf_save_callable_object(self):
},
},
}

def test_infer_signature(self):
pf = PFClient()
flow_meta, code = pf.flows._infer_signature(entry=global_hello)
assert flow_meta == {
"inputs": {
"text": {
"type": "string",
}
},
"outputs": {
"output": {
"type": "string",
},
},
}

with pytest.raises(UserErrorException, match="Schema validation failed: {'init.words.type'"):
pf.flows._infer_signature(entry=GlobalHelloWithInvalidInit)

0 comments on commit a459064

Please sign in to comment.