From f0175c4e15091db931b2a926d6f607dbf3994246 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Mon, 23 Dec 2024 10:54:22 +0800 Subject: [PATCH] fix: set USE_DOCSTRING as default for ai_callable (#1266) --- .changeset/nasty-rings-wave.md | 5 ++++ .../livekit/agents/llm/function_context.py | 14 +++++------ tests/test_create_func.py | 23 +++++++++++++++++-- 3 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 .changeset/nasty-rings-wave.md diff --git a/.changeset/nasty-rings-wave.md b/.changeset/nasty-rings-wave.md new file mode 100644 index 000000000..cbbcb7979 --- /dev/null +++ b/.changeset/nasty-rings-wave.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +set USE_DOCSTRING as default for ai_callable diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 4470492fe..59604fc8d 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -105,7 +105,7 @@ class CalledFunction: def ai_callable( *, name: str | None = None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = False, ) -> Callable: def deco(f): @@ -127,7 +127,7 @@ def ai_callable( self, *, name: str | None = None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = True, ) -> Callable: def deco(f): @@ -243,19 +243,17 @@ def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]: def _set_metadata( f: Callable, name: str | None = None, - desc: str | _UseDocMarker | None = None, + desc: str | _UseDocMarker = USE_DOCSTRING, auto_retry: bool = False, ) -> None: - if desc is None: - desc = "" - if isinstance(desc, _UseDocMarker): - desc = inspect.getdoc(f) - if desc is None: + docstring = inspect.getdoc(f) + if docstring is None: raise ValueError( f"missing docstring for function {f.__name__}, " "use explicit description or provide docstring" ) + desc = docstring metadata = _AIFncMetadata( name=name or f.__name__, description=desc, auto_retry=auto_retry diff --git a/tests/test_create_func.py b/tests/test_create_func.py index 97583fb36..a81d31d93 100644 --- a/tests/test_create_func.py +++ b/tests/test_create_func.py @@ -43,11 +43,15 @@ def test_fn( def test_func_duplicate(): class TestFunctionContext(llm.FunctionContext): - @llm.ai_callable(name="duplicate_function") + @llm.ai_callable( + name="duplicate_function", description="A simple test function" + ) def fn1(self): pass - @llm.ai_callable(name="duplicate_function") + @llm.ai_callable( + name="duplicate_function", description="A simple test function" + ) def fn2(self): pass @@ -57,6 +61,21 @@ def fn2(self): TestFunctionContext() +def test_func_with_docstring(): + class TestFunctionContext(llm.FunctionContext): + @llm.ai_callable() + def test_fn(self): + """A simple test function""" + pass + + fnc_ctx = TestFunctionContext() + assert ( + "test_fn" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" + + assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function" + + def test_func_with_optional_parameter(): class TestFunctionContext(llm.FunctionContext): @llm.ai_callable(