diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 9d16d6392..a55a2a311 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -22,7 +22,7 @@ from .._types import AgentState from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream from . import metrics -from .agent_output import AgentOutput, SpeechSource, SynthesisHandle +from .agent_output import AgentOutput, PlayoutHandle, SpeechSource, SynthesisHandle from .agent_playout import AgentPlayout from .human_input import HumanInput from .log import logger @@ -661,66 +661,14 @@ async def _play_speech(self, speech_handle: SpeechHandle) -> None: user_question = speech_handle.user_question play_handle = synthesis_handle.play() - join_fut = play_handle.join() - - def _commit_user_question_if_needed() -> None: - if ( - not user_question - or synthesis_handle.interrupted - or speech_handle.user_commited - ): - return - - is_using_tools = isinstance(speech_handle.source, LLMStream) and len( - speech_handle.source.function_calls - ) - # make sure at least some speech was played before committing the user message - # since we try to validate as fast as possible it is possible the agent gets interrupted - # really quickly (barely audible), we don't want to mark this question as "answered". 
- if ( - speech_handle.allow_interruptions - and not is_using_tools - and ( - play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT - and not join_fut.done() - ) - ): - return - - user_msg = ChatMessage.create(text=user_question, role="user") - self._chat_ctx.messages.append(user_msg) - self.emit("user_speech_committed", user_msg) - - self._transcribed_text = self._transcribed_text[len(user_question) :] - speech_handle.mark_user_commited() - - # wait for the play_handle to finish and check every 1s if the user question should be committed - _commit_user_question_if_needed() - - while not join_fut.done(): - await asyncio.wait( - [join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2 - ) - - _commit_user_question_if_needed() - - if speech_handle.interrupted: - break - - _commit_user_question_if_needed() - - collected_text = speech_handle.synthesis_handle.tts_forwarder.played_text - interrupted = speech_handle.interrupted - is_using_tools = isinstance(speech_handle.source, LLMStream) and len( - speech_handle.source.function_calls - ) + await self._wait_for_play_completion(speech_handle, play_handle) extra_tools_messages = [] # additional messages from the functions to add to the context if needed # if the answer is using tools, execute the functions and automatically generate # a response to the user question from the returned values - if is_using_tools and not interrupted: + if speech_handle.is_using_tools and not speech_handle.interrupted: assert isinstance(speech_handle.source, LLMStream) assert ( not user_question or speech_handle.user_commited @@ -778,7 +726,9 @@ def _commit_user_question_if_needed() -> None: # generate an answer from the tool calls extra_tools_messages.append( - ChatMessage.create_tool_calls(tool_calls_info, text=collected_text) + ChatMessage.create_tool_calls( + tool_calls_info, text=speech_handle.collected_text() + ) ) extra_tools_messages.extend(tool_calls_results) @@ -799,8 +749,6 @@ def _commit_user_question_if_needed() -> None: 
play_handle = answer_synthesis.play() await play_handle.join() - collected_text = answer_synthesis.tts_forwarder.played_text - interrupted = answer_synthesis.interrupted new_function_calls = answer_llm_stream.function_calls self.emit("function_calls_finished", called_fncs) @@ -815,15 +763,14 @@ def _commit_user_question_if_needed() -> None: ): self._chat_ctx.messages.extend(extra_tools_messages) - if interrupted: - collected_text += "..." - - msg = ChatMessage.create(text=collected_text, role="assistant") + msg = ChatMessage.create( + text=speech_handle.collected_text(), role="assistant" + ) self._chat_ctx.messages.append(msg) speech_handle.mark_speech_commited() - if interrupted: + if speech_handle.interrupted: self.emit("agent_speech_interrupted", msg) else: self.emit("agent_speech_committed", msg) @@ -831,12 +778,64 @@ def _commit_user_question_if_needed() -> None: logger.debug( "committed agent speech", extra={ - "agent_transcript": collected_text, - "interrupted": interrupted, + "agent_transcript": speech_handle.collected_text(), + "interrupted": speech_handle.interrupted, "speech_id": speech_handle.id, }, ) + async def _wait_for_play_completion( + self, speech_handle: SpeechHandle, play_handle: PlayoutHandle + ) -> None: + user_question = speech_handle.user_question + join_fut = play_handle.join() + + def _commit_user_question_if_needed() -> None: + if ( + not user_question + or speech_handle.synthesis_handle.interrupted + or speech_handle.user_commited + ): + return + + # make sure at least some speech was played before committing the user message + # since we try to validate as fast as possible it is possible the agent gets interrupted + # really quickly (barely audible), we don't want to mark this question as "answered". 
+            if (
+                speech_handle.allow_interruptions
+                and not speech_handle.is_using_tools
+                and (
+                    play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT
+                    and not join_fut.done()
+                )
+            ):
+                return
+
+            logger.debug(
+                "committed user transcript", extra={"user_transcript": user_question}
+            )
+            user_msg = ChatMessage.create(text=user_question, role="user")
+            self._chat_ctx.messages.append(user_msg)
+            self.emit("user_speech_committed", user_msg)
+
+            self._transcribed_text = self._transcribed_text[len(user_question) :]
+            speech_handle.mark_user_commited()
+
+        # wait for the play_handle to finish and check every 0.2s if the user question should be committed
+        _commit_user_question_if_needed()
+
+        while not join_fut.done():
+            await asyncio.wait(
+                [join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2
+            )
+
+            _commit_user_question_if_needed()
+
+            if speech_handle.interrupted:
+                break
+
+        _commit_user_question_if_needed()
+
     def _synthesize_agent_speech(
         self,
         speech_id: str,
diff --git a/livekit-agents/livekit/agents/pipeline/speech_handle.py b/livekit-agents/livekit/agents/pipeline/speech_handle.py
index 194c4973e..ba200316b 100644
--- a/livekit-agents/livekit/agents/pipeline/speech_handle.py
+++ b/livekit-agents/livekit/agents/pipeline/speech_handle.py
@@ -64,6 +64,11 @@ def create_assistant_speech(
             user_question="",
         )
 
+    def collected_text(self) -> str:
+        if self.interrupted:
+            return self.synthesis_handle.tts_forwarder.played_text + "..."
+        return self.synthesis_handle.tts_forwarder.played_text
+
     async def wait_for_initialization(self) -> None:
         await asyncio.shield(self._init_fut)
 
@@ -87,6 +92,10 @@ def mark_user_commited(self) -> None:
     def mark_speech_commited(self) -> None:
         self._speech_commited = True
 
+    @property
+    def is_using_tools(self) -> bool:
+        return isinstance(self.source, LLMStream) and bool(self.source.function_calls)
+
     @property
     def user_commited(self) -> bool:
         return self._user_commited