Skip to content

Commit

Permalink
fix: set USE_DOCSTRING as default for ai_callable (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
longcw authored Dec 23, 2024
1 parent 4b72303 commit f0175c4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .changeset/nasty-rings-wave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

set USE_DOCSTRING as default for ai_callable
14 changes: 6 additions & 8 deletions livekit-agents/livekit/agents/llm/function_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class CalledFunction:
def ai_callable(
*,
name: str | None = None,
description: str | _UseDocMarker | None = None,
description: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = False,
) -> Callable:
def deco(f):
Expand All @@ -127,7 +127,7 @@ def ai_callable(
self,
*,
name: str | None = None,
description: str | _UseDocMarker | None = None,
description: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = True,
) -> Callable:
def deco(f):
Expand Down Expand Up @@ -243,19 +243,17 @@ def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]:
def _set_metadata(
f: Callable,
name: str | None = None,
desc: str | _UseDocMarker | None = None,
desc: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = False,
) -> None:
if desc is None:
desc = ""

if isinstance(desc, _UseDocMarker):
desc = inspect.getdoc(f)
if desc is None:
docstring = inspect.getdoc(f)
if docstring is None:
raise ValueError(
f"missing docstring for function {f.__name__}, "
"use explicit description or provide docstring"
)
desc = docstring

metadata = _AIFncMetadata(
name=name or f.__name__, description=desc, auto_retry=auto_retry
Expand Down
23 changes: 21 additions & 2 deletions tests/test_create_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ def test_fn(

def test_func_duplicate():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(name="duplicate_function")
@llm.ai_callable(
name="duplicate_function", description="A simple test function"
)
def fn1(self):
pass

@llm.ai_callable(name="duplicate_function")
@llm.ai_callable(
name="duplicate_function", description="A simple test function"
)
def fn2(self):
pass

Expand All @@ -57,6 +61,21 @@ def fn2(self):
TestFunctionContext()


def test_func_with_docstring():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable()
def test_fn(self):
"""A simple test function"""
pass

fnc_ctx = TestFunctionContext()
assert (
"test_fn" in fnc_ctx.ai_functions
), "Function should be registered in ai_functions"

assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function"


def test_func_with_optional_parameter():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(
Expand Down

0 comments on commit f0175c4

Please sign in to comment.