Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fnc calls #6

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 64 additions & 65 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .._types import AgentState
from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
from . import metrics
from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
from .agent_output import AgentOutput, PlayoutHandle, SpeechSource, SynthesisHandle
from .agent_playout import AgentPlayout
from .human_input import HumanInput
from .log import logger
Expand Down Expand Up @@ -661,66 +661,14 @@ async def _play_speech(self, speech_handle: SpeechHandle) -> None:
user_question = speech_handle.user_question

play_handle = synthesis_handle.play()
join_fut = play_handle.join()

def _commit_user_question_if_needed() -> None:
if (
not user_question
or synthesis_handle.interrupted
or speech_handle.user_commited
):
return

is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
speech_handle.source.function_calls
)

# make sure at least some speech was played before committing the user message
# since we try to validate as fast as possible it is possible the agent gets interrupted
# really quickly (barely audible), we don't want to mark this question as "answered".
if (
speech_handle.allow_interruptions
and not is_using_tools
and (
play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT
and not join_fut.done()
)
):
return

user_msg = ChatMessage.create(text=user_question, role="user")
self._chat_ctx.messages.append(user_msg)
self.emit("user_speech_committed", user_msg)

self._transcribed_text = self._transcribed_text[len(user_question) :]
speech_handle.mark_user_commited()

# wait for the play_handle to finish and check every 1s if the user question should be committed
_commit_user_question_if_needed()

while not join_fut.done():
await asyncio.wait(
[join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2
)

_commit_user_question_if_needed()

if speech_handle.interrupted:
break

_commit_user_question_if_needed()

collected_text = speech_handle.synthesis_handle.tts_forwarder.played_text
interrupted = speech_handle.interrupted
is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
speech_handle.source.function_calls
)
await self._wait_for_play_completion(speech_handle, play_handle)

extra_tools_messages = [] # additional messages from the functions to add to the context if needed

# if the answer is using tools, execute the functions and automatically generate
# a response to the user question from the returned values
if is_using_tools and not interrupted:
if speech_handle.is_using_tools and not speech_handle.interrupted:
assert isinstance(speech_handle.source, LLMStream)
assert (
not user_question or speech_handle.user_commited
Expand Down Expand Up @@ -778,7 +726,9 @@ def _commit_user_question_if_needed() -> None:

# generate an answer from the tool calls
extra_tools_messages.append(
ChatMessage.create_tool_calls(tool_calls_info, text=collected_text)
ChatMessage.create_tool_calls(
tool_calls_info, text=speech_handle.collected_text()
)
)
extra_tools_messages.extend(tool_calls_results)

Expand All @@ -799,8 +749,6 @@ def _commit_user_question_if_needed() -> None:
play_handle = answer_synthesis.play()
await play_handle.join()

collected_text = answer_synthesis.tts_forwarder.played_text
interrupted = answer_synthesis.interrupted
new_function_calls = answer_llm_stream.function_calls

self.emit("function_calls_finished", called_fncs)
Expand All @@ -815,28 +763,79 @@ def _commit_user_question_if_needed() -> None:
):
self._chat_ctx.messages.extend(extra_tools_messages)

if interrupted:
collected_text += "..."

msg = ChatMessage.create(text=collected_text, role="assistant")
msg = ChatMessage.create(
text=speech_handle.collected_text(), role="assistant"
)
self._chat_ctx.messages.append(msg)

speech_handle.mark_speech_commited()

if interrupted:
if speech_handle.interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)

logger.debug(
"committed agent speech",
extra={
"agent_transcript": collected_text,
"interrupted": interrupted,
"agent_transcript": speech_handle.collected_text(),
"interrupted": speech_handle.interrupted,
"speech_id": speech_handle.id,
},
)

async def _wait_for_play_completion(
self, speech_handle: SpeechHandle, play_handle: PlayoutHandle
) -> None:
user_question = speech_handle.user_question
join_fut = play_handle.join()

def _commit_user_question_if_needed() -> None:
if (
not user_question
or speech_handle.synthesis_handle.interrupted
or speech_handle.user_commited
):
return

# make sure at least some speech was played before committing the user message
# since we try to validate as fast as possible it is possible the agent gets interrupted
# really quickly (barely audible), we don't want to mark this question as "answered".
if (
speech_handle.allow_interruptions
and not speech_handle.is_using_tools
and (
play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT
and not join_fut.done()
)
):
return

logger.debug(
"committed user transcript", extra={"user_transcript": user_question}
)
user_msg = ChatMessage.create(text=user_question, role="user")
self._chat_ctx.messages.append(user_msg)
self.emit("user_speech_committed", user_msg)

self._transcribed_text = self._transcribed_text[len(user_question) :]
speech_handle.mark_user_commited()

# wait for the play_handle to finish and check every 0.5s if the user question should be committed
_commit_user_question_if_needed()

while not join_fut.done():
await asyncio.wait(
[join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2
)

_commit_user_question_if_needed()

if speech_handle.interrupted:
break

_commit_user_question_if_needed()

def _synthesize_agent_speech(
self,
speech_id: str,
Expand Down
8 changes: 8 additions & 0 deletions livekit-agents/livekit/agents/pipeline/speech_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def create_assistant_speech(
user_question="",
)

def collected_text(self) -> str:
if self.interrupted:
return self.synthesis_handle.tts_forwarder.played_text + "..."
return self.synthesis_handle.tts_forwarder.played_text

async def wait_for_initialization(self) -> None:
await asyncio.shield(self._init_fut)

Expand All @@ -87,6 +92,9 @@ def mark_user_commited(self) -> None:
def mark_speech_commited(self) -> None:
self._speech_commited = True

def is_using_tools(self) -> bool:
return isinstance(self.source, LLMStream) and len(self.source.function_calls)

@property
def user_commited(self) -> bool:
return self._user_commited
Expand Down
Loading