From f65c990afc2c38c9ebca8a5e3c63267ccee1bb16 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 15 Dec 2024 12:40:05 +0800 Subject: [PATCH 01/13] feat: multi stage agent testing --- .../voice-pipeline-agent/multi_stage_agent.py | 227 ++++++++++++++++++ .../livekit/agents/pipeline/pipeline_agent.py | 3 +- 2 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 examples/voice-pipeline-agent/multi_stage_agent.py diff --git a/examples/voice-pipeline-agent/multi_stage_agent.py b/examples/voice-pipeline-agent/multi_stage_agent.py new file mode 100644 index 000000000..1029bc7e1 --- /dev/null +++ b/examples/voice-pipeline-agent/multi_stage_agent.py @@ -0,0 +1,227 @@ +import json +import logging +from dataclasses import dataclass +from typing import Annotated + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("multi-stage-agent") +logger.setLevel(logging.INFO) + + +@dataclass +class AgentSpec: + instructions: str + fnc_ctx: llm.FunctionContext + + @classmethod + def create(cls, instructions: str, fncs: dict[str, llm.FunctionInfo]): + spec = cls(instructions=instructions, fnc_ctx=llm.FunctionContext()) + spec.fnc_ctx._fncs.update(fncs) + return spec + + +class RestaurantBot(llm.FunctionContext): + def __init__(self): + super().__init__() + + self._specs = { + "Greeter": AgentSpec.create( + instructions=( + "You are a professional restaurant receptionist handling incoming calls. " + "Warmly greet the caller and ask if they would like to place an order. " + "Available menu items: Pizza, Salad, Ice Cream, Coffee. " + "Guide the conversation as follows:\n" + "- If they want to place an order, transfer them to order taking\n" + "- If they have completed their order, transfer them to customer details\n" + "- For any other inquiries, assist them directly\n" + "Maintain a friendly and professional tone throughout the conversation." + "Use the functions to transfer the call to the next step." + ), + fncs={ + "transfer_to_ordering": self._fncs["transfer_to_ordering"], + "transfer_to_info_collection": self._fncs[ + "transfer_to_info_collection" + ], + }, + ), + "OrderTaking": AgentSpec.create( + instructions=( + "You are a professional order taker at a restaurant. " + "Guide the customer through their order with these steps:\n" + "1. Take their order selections one at a time from our menu: Pizza, Salad, Ice Cream, Coffee\n" + "2. Clarify any special requests or modifications\n" + "3. Repeat back the complete order to confirm accuracy\n" + "4. Once confirmed, transfer them back to the greeter\n" + "Be attentive and ensure order accuracy before proceeding." + ), + fncs={ + "take_order": self._fncs["take_order"], + "transfer_to_greeter": self._fncs["transfer_to_greeter"], + }, + ), + "CustomerDetails": AgentSpec.create( + instructions=( + "You are collecting essential customer information for their order. " + "Follow these steps carefully:\n" + "1. Ask for the customer's name and confirm the spelling\n" + "2. Request their phone number and verify it's correct\n" + "3. Repeat both pieces of information back to ensure accuracy\n" + "4. 
Once confirmed, transfer back to the greeter\n" + "Handle personal information professionally and courteously." + ), + fncs={ + "collect_name": self._fncs["collect_name"], + "collect_phone": self._fncs["collect_phone"], + "transfer_to_greeter": self._fncs["transfer_to_greeter"], + }, + ), + } + + self._cur_spec = self._specs["Greeter"] + + def _transfer_to_spec(self, spec_name: str, agent: VoicePipelineAgent) -> None: + self._cur_spec = self._specs[spec_name] + # TODO: update chat ctx for each spec + # agent._chat_ctx = self.get_chat_ctx(agent._chat_ctx) + logger.info(f"Transferring to {spec_name}") + + def get_chat_ctx(self, chat_ctx: llm.ChatContext | None = None) -> llm.ChatContext: + """Get the chat context for the current spec""" + new_chat_ctx = llm.ChatContext().append( + text=self._cur_spec.instructions, + role="system", + ) + if chat_ctx: + messages = chat_ctx.messages + if messages and messages[0].role == "system": + messages = messages[1:] + + # # Greeter has all the chat history, others have the last 6 messages + # if self._cur_spec != "Greeter": + # messages = messages[-6:] + new_chat_ctx.messages.extend(messages) + + return new_chat_ctx + + def before_llm_callback( + self, agent: VoicePipelineAgent, chat_ctx: llm.ChatContext + ) -> llm.LLMStream: + return agent.llm.chat( + chat_ctx=self.get_chat_ctx(chat_ctx), + fnc_ctx=self._cur_spec.fnc_ctx, + parallel_tool_calls=False, + ) + + @llm.ai_callable() + async def take_order( + self, + item: Annotated[str, llm.TypeInfo(description="The item added to the order")], + ): + """Called when the user orders a new item from our menu.""" + logger.info(f"Taking order for {item}") + return f"Received order for {item}" + + @llm.ai_callable() + async def collect_name( + self, name: Annotated[str, llm.TypeInfo(description="The customer's name")] + ): + """Called when the user provides their name.""" + logger.info(f"Collecting name: {name}") + return f"Please confirm with the customer that their name is {name}." + + @llm.ai_callable() + async def collect_phone( + self, + phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], + ): + """Called when the user provides their phone number.""" + logger.info(f"Collecting phone: {phone}") + return f"Please confirm with the customer that their phone number is {phone}." + + @llm.ai_callable() + async def transfer_to_ordering(self): + """Called to transfer the call to order taking.""" + call_ctx = AgentCallContext.get_current() + self._transfer_to_spec("OrderTaking", call_ctx.agent) + return "Transferred to order taking." + + @llm.ai_callable() + async def transfer_to_info_collection(self): + """Called to transfer the call to collect the customer's details.""" + call_ctx = AgentCallContext.get_current() + self._transfer_to_spec("CustomerDetails", call_ctx.agent) + return "Transferred to collecting customer details." + + @llm.ai_callable() + async def transfer_to_greeter(self): + """Called to transfer the call back to the greeter.""" + call_ctx = AgentCallContext.get_current() + self._transfer_to_spec("Greeter", call_ctx.agent) + return "Back to the greeter." 
+ + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + fnc_ctx = RestaurantBot() + initial_chat_ctx = fnc_ctx.get_chat_ctx() + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=openai.TTS(), + fnc_ctx=fnc_ctx, + chat_ctx=initial_chat_ctx, + before_llm_cb=fnc_ctx.before_llm_callback, + # preemptive_synthesis=True, + ) + + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.info(f"Text input received: {data}") + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + # Start the assistant. This will automatically publish a microphone track and listen to the participant. + agent.start(ctx.room, participant) + await agent.say( + "Welcome to our restaurant! How may I assist you with your order today?" + ) + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index a08291ea4..0bdf5223d 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -920,7 +920,8 @@ async def _execute_function_calls() -> None: chat_ctx = call_ctx.chat_ctx.copy() chat_ctx.messages.extend(extra_tools_messages) chat_ctx.messages.extend(call_ctx.extra_chat_messages) - answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) + # answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) + answer_llm_stream = self._opts.before_llm_cb(self, chat_ctx) synthesis_handle = self._synthesize_agent_speech( new_speech_handle.id, answer_llm_stream From 9e1faa2939d010ed0bb063eae678c59277d836e3 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 15 Dec 2024 18:37:24 +0800 Subject: [PATCH 02/13] update multi stage example with better transfer --- .../voice-pipeline-agent/multi_stage_agent.py | 193 ++++++++++-------- .../livekit/agents/pipeline/pipeline_agent.py | 7 +- 2 files changed, 115 insertions(+), 85 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_stage_agent.py b/examples/voice-pipeline-agent/multi_stage_agent.py index 1029bc7e1..d687d51ba 100644 --- a/examples/voice-pipeline-agent/multi_stage_agent.py +++ b/examples/voice-pipeline-agent/multi_stage_agent.py @@ -1,7 +1,7 @@ import json import logging from dataclasses import dataclass -from typing import Annotated +from typing import Annotated, Callable from dotenv import load_dotenv from livekit import rtc @@ -25,26 +25,27 @@ @dataclass class AgentSpec: - instructions: str + chat_ctx: llm.ChatContext fnc_ctx: llm.FunctionContext @classmethod - def create(cls, instructions: str, fncs: dict[str, llm.FunctionInfo]): - spec = cls(instructions=instructions, fnc_ctx=llm.FunctionContext()) - spec.fnc_ctx._fncs.update(fncs) - return spec + def create(cls, instructions: str, fncs: list[Callable]): + chat_ctx = llm.ChatContext().append(text=instructions, role="system") + fnc_ctx = llm.FunctionContext() + for fnc in 
fncs: + fnc_ctx._register_ai_function(fnc) + return cls(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) -class RestaurantBot(llm.FunctionContext): - def __init__(self): - super().__init__() - +class RestaurantBot: + def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): + self._menu = menu self._specs = { "Greeter": AgentSpec.create( instructions=( "You are a professional restaurant receptionist handling incoming calls. " "Warmly greet the caller and ask if they would like to place an order. " - "Available menu items: Pizza, Salad, Ice Cream, Coffee. " + f"Available menu items: {self._menu}. " "Guide the conversation as follows:\n" "- If they want to place an order, transfer them to order taking\n" "- If they have completed their order, transfer them to customer details\n" @@ -52,27 +53,25 @@ def __init__(self): "Maintain a friendly and professional tone throughout the conversation." "Use the functions to transfer the call to the next step." ), - fncs={ - "transfer_to_ordering": self._fncs["transfer_to_ordering"], - "transfer_to_info_collection": self._fncs[ - "transfer_to_info_collection" - ], - }, + fncs=[ + self.transfer_to_ordering, + self.transfer_to_info_collection, + ], ), "OrderTaking": AgentSpec.create( instructions=( "You are a professional order taker at a restaurant. " "Guide the customer through their order with these steps:\n" - "1. Take their order selections one at a time from our menu: Pizza, Salad, Ice Cream, Coffee\n" + f"1. Take their order selections one at a time from our menu: {self._menu}\n" "2. Clarify any special requests or modifications\n" "3. Repeat back the complete order to confirm accuracy\n" "4. Once confirmed, transfer them back to the greeter\n" "Be attentive and ensure order accuracy before proceeding." ), - fncs={ - "take_order": self._fncs["take_order"], - "transfer_to_greeter": self._fncs["transfer_to_greeter"], - }, + fncs=[ + self.update_order, + self.transfer_to_greeter, + ], ), "CustomerDetails": AgentSpec.create( instructions=( @@ -84,65 +83,64 @@ def __init__(self): "4. Once confirmed, transfer back to the greeter\n" "Handle personal information professionally and courteously." 
), - fncs={ - "collect_name": self._fncs["collect_name"], - "collect_phone": self._fncs["collect_phone"], - "transfer_to_greeter": self._fncs["transfer_to_greeter"], - }, + fncs=[ + self.collect_name, + self.collect_phone, + self.transfer_to_greeter, + ], ), } - self._cur_spec = self._specs["Greeter"] + self._cur_spec = "Greeter" + self._order: str | None = None + self._customer_name: str | None = None + self._customer_phone: str | None = None - def _transfer_to_spec(self, spec_name: str, agent: VoicePipelineAgent) -> None: - self._cur_spec = self._specs[spec_name] - # TODO: update chat ctx for each spec - # agent._chat_ctx = self.get_chat_ctx(agent._chat_ctx) - logger.info(f"Transferring to {spec_name}") + @property + def spec(self) -> AgentSpec: + return self._specs[self._cur_spec] - def get_chat_ctx(self, chat_ctx: llm.ChatContext | None = None) -> llm.ChatContext: - """Get the chat context for the current spec""" - new_chat_ctx = llm.ChatContext().append( - text=self._cur_spec.instructions, - role="system", - ) - if chat_ctx: - messages = chat_ctx.messages - if messages and messages[0].role == "system": - messages = messages[1:] - - # # Greeter has all the chat history, others have the last 6 messages - # if self._cur_spec != "Greeter": - # messages = messages[-6:] - new_chat_ctx.messages.extend(messages) - - return new_chat_ctx - - def before_llm_callback( - self, agent: VoicePipelineAgent, chat_ctx: llm.ChatContext - ) -> llm.LLMStream: - return agent.llm.chat( - chat_ctx=self.get_chat_ctx(chat_ctx), - fnc_ctx=self._cur_spec.fnc_ctx, - parallel_tool_calls=False, - ) + def _transfer_to_spec(self, spec_name: str, call_ctx: AgentCallContext) -> None: + agent = call_ctx.agent + + keep_last_n = 6 + prev_messages = agent.chat_ctx.messages.copy() + while prev_messages and prev_messages[0].role in ["system", "tool"]: + prev_messages.pop(0) + prev_messages = prev_messages[-keep_last_n:] + + self._cur_spec = spec_name + agent._fnc_ctx = self.spec.fnc_ctx + agent._chat_ctx = self.spec.chat_ctx + agent._chat_ctx.messages.extend(prev_messages) + + # use the new chat_ctx in the call_ctx + call_ctx.chat_ctx.messages = agent.chat_ctx.messages.copy() + logger.info(f"Transferred to {spec_name}") @llm.ai_callable() - async def take_order( + async def update_order( self, - item: Annotated[str, llm.TypeInfo(description="The item added to the order")], + item: Annotated[ + str, + llm.TypeInfo( + description="The items of the full order, separated by commas" + ), + ], ): - """Called when the user orders a new item from our menu.""" - logger.info(f"Taking order for {item}") - return f"Received order for {item}" + """Called when the user updates their order.""" + self._order = item + logger.info("Updated order", extra={"order": item}) + return f"Updated order to {item}" @llm.ai_callable() async def collect_name( self, name: Annotated[str, llm.TypeInfo(description="The customer's name")] ): """Called when the user provides their name.""" - logger.info(f"Collecting name: {name}") - return f"Please confirm with the customer that their name is {name}." 
+ self._customer_name = name + logger.info("Collected name", extra={"customer_name": name}) + return f"The name is updated to {name}" @llm.ai_callable() async def collect_phone( @@ -150,29 +148,58 @@ async def collect_phone( phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], ): """Called when the user provides their phone number.""" - logger.info(f"Collecting phone: {phone}") - return f"Please confirm with the customer that their phone number is {phone}." + # validate phone number + phone = phone.strip().replace("-", "") + if not phone.isdigit() or len(phone) != 10: + return "The phone number is not valid, please try again." + + self._customer_phone = phone + logger.info("Collected phone", extra={"customer_phone": phone}) + return f"The phone number is updated to {phone}" @llm.ai_callable() async def transfer_to_ordering(self): """Called to transfer the call to order taking.""" call_ctx = AgentCallContext.get_current() - self._transfer_to_spec("OrderTaking", call_ctx.agent) - return "Transferred to order taking." + self._transfer_to_spec("OrderTaking", call_ctx) + return f"Transferred to order taking, the current order is {self._order}" @llm.ai_callable() async def transfer_to_info_collection(self): """Called to transfer the call to collect the customer's details.""" call_ctx = AgentCallContext.get_current() - self._transfer_to_spec("CustomerDetails", call_ctx.agent) - return "Transferred to collecting customer details." + self._transfer_to_spec("CustomerDetails", call_ctx) + return ( + f"Transferred to collecting customer details, " + f"the current collected name is {self._customer_name} " + f"and phone number is {self._customer_phone}" + ) @llm.ai_callable() - async def transfer_to_greeter(self): + async def transfer_to_greeter( + self, + # summary: Annotated[ + # str, + # llm.TypeInfo( + # description="The summary of conversations in the current stage" + # ), + # ], + ): """Called to transfer the call back to the greeter.""" + # message = f"Back to the greeter from {self._cur_spec}, the summary of conversations is {summary}" + message = f"Back to the greeter from {self._cur_spec}. " + if self._cur_spec == "OrderTaking": + message += f"The current order is {self._order}" + elif self._cur_spec == "CustomerDetails": + message += ( + f"The current collected name is {self._customer_name} " + f"and phone number is {self._customer_phone}" + ) + logger.info("Back to greeter", extra={"summary": message}) + call_ctx = AgentCallContext.get_current() - self._transfer_to_spec("Greeter", call_ctx.agent) - return "Back to the greeter." 
+ self._transfer_to_spec("Greeter", call_ctx) + return message def prewarm_process(proc: JobProcess): @@ -182,8 +209,9 @@ def prewarm_process(proc: JobProcess): async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - fnc_ctx = RestaurantBot() - initial_chat_ctx = fnc_ctx.get_chat_ctx() + + menu = "Pizza, Salad, Ice Cream, Coffee" + multi_stage_ctx = RestaurantBot(menu) participant = await ctx.wait_for_participant() agent = VoicePipelineAgent( @@ -191,17 +219,16 @@ async def entrypoint(ctx: JobContext): stt=deepgram.STT(), llm=openai.LLM(), tts=openai.TTS(), - fnc_ctx=fnc_ctx, - chat_ctx=initial_chat_ctx, - before_llm_cb=fnc_ctx.before_llm_callback, - # preemptive_synthesis=True, + fnc_ctx=multi_stage_ctx.spec.fnc_ctx, + chat_ctx=multi_stage_ctx.spec.chat_ctx, + max_nested_fnc_calls=2, # may call functions in the transition function ) @ctx.room.on("data_received") def on_data_received(packet: rtc.DataPacket): if packet.topic == "lk-chat-topic": data = json.loads(packet.data.decode("utf-8")) - logger.info(f"Text input received: {data}") + logger.info(f"Text input received: {data['message']}") agent._human_input.emit( "final_transcript", @@ -214,7 +241,7 @@ def on_data_received(packet: rtc.DataPacket): # Start the assistant. This will automatically publish a microphone track and listen to the participant. agent.start(ctx.room, participant) await agent.say( - "Welcome to our restaurant! How may I assist you with your order today?" + f"Welcome to our restaurant! We offer {menu}. How may I assist you today?" ) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 0bdf5223d..64ad68757 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -845,6 +845,10 @@ async def _execute_function_calls() -> None: extra={ "speech_id": speech_handle.id, "fnc_nested_depth": speech_handle.fnc_nested_depth, + "fnc_names": [ + fnc.function_info.name + for fnc in speech_handle.source.function_calls + ], }, ) return @@ -920,8 +924,7 @@ async def _execute_function_calls() -> None: chat_ctx = call_ctx.chat_ctx.copy() chat_ctx.messages.extend(extra_tools_messages) chat_ctx.messages.extend(call_ctx.extra_chat_messages) - # answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) - answer_llm_stream = self._opts.before_llm_cb(self, chat_ctx) + answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) synthesis_handle = self._synthesize_agent_speech( new_speech_handle.id, answer_llm_stream From 50f9123e39f457760b3b411a19235d5b1c135bab Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 15 Dec 2024 18:50:32 +0800 Subject: [PATCH 03/13] fix: add filler messages and fix transfer_to_spec --- .../voice-pipeline-agent/multi_stage_agent.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_stage_agent.py b/examples/voice-pipeline-agent/multi_stage_agent.py index d687d51ba..e1f67fcf8 100644 --- a/examples/voice-pipeline-agent/multi_stage_agent.py +++ b/examples/voice-pipeline-agent/multi_stage_agent.py @@ -103,11 +103,11 @@ def spec(self) -> AgentSpec: def _transfer_to_spec(self, spec_name: str, call_ctx: AgentCallContext) -> None: agent = call_ctx.agent + # keep the last n messages for the next stage keep_last_n = 6 - prev_messages = agent.chat_ctx.messages.copy() + prev_messages = 
agent.chat_ctx.messages.copy()[-keep_last_n:] while prev_messages and prev_messages[0].role in ["system", "tool"]: prev_messages.pop(0) - prev_messages = prev_messages[-keep_last_n:] self._cur_spec = spec_name agent._fnc_ctx = self.spec.fnc_ctx @@ -148,10 +148,10 @@ async def collect_phone( phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], ): """Called when the user provides their phone number.""" - # validate phone number + # validate phone number (optional) phone = phone.strip().replace("-", "") if not phone.isdigit() or len(phone) != 10: - return "The phone number is not valid, please try again." + return "The phone number is invalid, it should be a 10-digit number, please try again." self._customer_phone = phone logger.info("Collected phone", extra={"customer_phone": phone}) @@ -161,14 +161,23 @@ async def collect_phone( async def transfer_to_ordering(self): """Called to transfer the call to order taking.""" call_ctx = AgentCallContext.get_current() + self._transfer_to_spec("OrderTaking", call_ctx) + await call_ctx.agent.say( + "I'll transfer you to our order taker who will help you with your selections." + ) + return f"Transferred to order taking, the current order is {self._order}" @llm.ai_callable() async def transfer_to_info_collection(self): """Called to transfer the call to collect the customer's details.""" call_ctx = AgentCallContext.get_current() + self._transfer_to_spec("CustomerDetails", call_ctx) + await call_ctx.agent.say( + "Great! I'll collect your contact information now." + ) return ( f"Transferred to collecting customer details, " f"the current collected name is {self._customer_name} " From 98144ecabf97eb4f4d94fb0203e4e23a6ec3368a Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 15 Dec 2024 19:16:53 +0800 Subject: [PATCH 04/13] log the chat ctx to file --- .../voice-pipeline-agent/multi_stage_agent.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_stage_agent.py b/examples/voice-pipeline-agent/multi_stage_agent.py index e1f67fcf8..b52a2fc86 100644 --- a/examples/voice-pipeline-agent/multi_stage_agent.py +++ b/examples/voice-pipeline-agent/multi_stage_agent.py @@ -175,9 +175,7 @@ async def transfer_to_info_collection(self): call_ctx = AgentCallContext.get_current() self._transfer_to_spec("CustomerDetails", call_ctx) - await call_ctx.agent.say( - "Great! I'll collect your contact information now." - ) + await call_ctx.agent.say("Great! 
I'll collect your contact information now.") return ( f"Transferred to collecting customer details, " f"the current collected name is {self._customer_name} " @@ -219,6 +217,7 @@ def prewarm_process(proc: JobProcess): async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + chat_log_file = "multi_stage_chat_log.txt" menu = "Pizza, Salad, Ice Cream, Coffee" multi_stage_ctx = RestaurantBot(menu) @@ -233,6 +232,7 @@ async def entrypoint(ctx: JobContext): max_nested_fnc_calls=2, # may call functions in the transition function ) + # For testing with text input @ctx.room.on("data_received") def on_data_received(packet: rtc.DataPacket): if packet.topic == "lk-chat-topic": @@ -247,6 +247,25 @@ def on_data_received(packet: rtc.DataPacket): ), ) + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + with open(chat_log_file, "a") as f: + f.write(f"{message.role}: {message.content}\n") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + with open(chat_log_file, "a") as f: + f.write(f"fnc_calls_collected: {fnc_infos}\n") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + with open(chat_log_file, "a") as f: + f.write(f"fnc_calls_finished: {called_infos}\n") + # Start the assistant. This will automatically publish a microphone track and listen to the participant. agent.start(ctx.room, participant) await agent.say( From a4712fd5ce70bfb66ebb8dcd94f1e7cb289ebfb1 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sat, 21 Dec 2024 13:04:03 +0800 Subject: [PATCH 05/13] add agent task --- .../livekit/agents/llm/chat_context.py | 2 +- .../livekit/agents/llm/function_context.py | 54 ++++++++-- .../livekit/agents/pipeline/agent_output.py | 10 +- .../livekit/agents/pipeline/agent_task.py | 89 +++++++++++++++++ .../livekit/agents/pipeline/pipeline_agent.py | 98 +++++++++++++++---- 5 files changed, 218 insertions(+), 35 deletions(-) create mode 100644 livekit-agents/livekit/agents/pipeline/agent_task.py diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index ccde86bba..4fe3ea5b4 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -113,7 +113,7 @@ def create_tool_from_called_function( tool_exception: Exception | None = None try: - content = called_function.task.result() + content = called_function.get_content() except BaseException as e: if isinstance(e, Exception): tool_exception = e diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 4470492fe..cf715d1df 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -22,10 +22,13 @@ import types import typing from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple from ..log import logger +if TYPE_CHECKING: + from ..pipeline.agent_task import AgentTask + class _UseDocMarker: pass @@ -101,6 +104,28 @@ class CalledFunction: result: Any | None = None exception: BaseException | 
None = None + def get_agent_task(self) -> "AgentTask" | None: + assert self.task.done() + + if isinstance(self.result, tuple): + assert len(self.result) == 2 and isinstance(self.result[0], AgentTask) + return self.result[0] + elif isinstance(self.result, AgentTask): + return self.result + return None + + def get_content(self) -> Any | None: + assert self.task.done() + + if self.exception: + return f"Error: {self.exception}" + if isinstance(self.result, tuple): + assert len(self.result) == 2 and isinstance(self.result[1], str) + return self.result[1] + elif not isinstance(self.result, AgentTask): + return self.result + return None + def ai_callable( *, @@ -136,16 +161,13 @@ def deco(f): return deco - def _register_ai_function(self, fnc: Callable) -> None: + @staticmethod + def _callable_to_fnc_info(fnc: Callable) -> FunctionInfo | None: if not hasattr(fnc, METADATA_ATTR): - logger.warning(f"function {fnc.__name__} does not have ai metadata") - return + return None metadata: _AIFncMetadata = getattr(fnc, METADATA_ATTR) fnc_name = metadata.name - if fnc_name in self._fncs: - raise ValueError(f"duplicate ai_callable name: {fnc_name}") - sig = inspect.signature(fnc) # get_type_hints with include_extra=True is needed when using Annotated @@ -190,7 +212,7 @@ def _register_ai_function(self, fnc: Callable) -> None: choices=choices, ) - self._fncs[metadata.name] = FunctionInfo( + return FunctionInfo( name=metadata.name, description=metadata.description, auto_retry=metadata.auto_retry, @@ -198,10 +220,26 @@ def _register_ai_function(self, fnc: Callable) -> None: arguments=args, ) + def _register_ai_function(self, fnc: Callable) -> None: + fnc_info = self._callable_to_fnc_info(fnc) + if not fnc_info: + logger.warning(f"function {fnc.__name__} does not have ai metadata") + return + + if fnc_info.name in self._fncs: + raise ValueError(f"duplicate ai_callable name: {fnc_info.name}") + + self._fncs[fnc_info.name] = fnc_info + @property def ai_functions(self) -> dict[str, FunctionInfo]: return self._fncs + def copy(self) -> "FunctionContext": + new_fnc_ctx = FunctionContext() + new_fnc_ctx._fncs.update(self._fncs) + return new_fnc_ctx + @dataclass(frozen=True) class _AIFncMetadata: diff --git a/livekit-agents/livekit/agents/pipeline/agent_output.py b/livekit-agents/livekit/agents/pipeline/agent_output.py index 14a836ef7..2deb80466 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_output.py +++ b/livekit-agents/livekit/agents/pipeline/agent_output.py @@ -6,7 +6,7 @@ from livekit import rtc -from .. import llm, tokenize, utils +from .. import tokenize, utils from .. import transcription as agent_transcription from .. 
import tts as text_to_speech from .agent_playout import AgentPlayout, PlayoutHandle @@ -96,15 +96,9 @@ def __init__( *, room: rtc.Room, agent_playout: AgentPlayout, - llm: llm.LLM, tts: text_to_speech.TTS, ) -> None: - self._room, self._agent_playout, self._llm, self._tts = ( - room, - agent_playout, - llm, - tts, - ) + self._room, self._agent_playout, self._tts = room, agent_playout, tts self._tasks = set[asyncio.Task[Any]]() @property diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py new file mode 100644 index 000000000..10818a81a --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -0,0 +1,89 @@ +from ..llm import LLM, ChatContext, FunctionContext +from ..llm.function_context import ( + METADATA_ATTR, + USE_DOCSTRING, + FunctionInfo, + ai_callable, +) +from ..stt import STT + +# class TaskContext: +# def __init__(self, assistant: "VoicePipelineAgent"): +# self._assistant = assistant + +# @property +# def agent(self) -> "VoicePipelineAgent": +# return self._assistant + +# @property +# def user_data(self) -> dict[str, Any]: +# return self._assistant.user_data + +# @property +# def current_task(self) -> "AgentTask" | None: +# return self._assistant._current_task + +# @property +# def room(self) -> rtc.Room: +# if not hasattr(self._assistant, "_room"): +# raise ValueError("VoicePipelineAgent is not started") +# return self._assistant._room + + +class AgentTask: + def __init__( + self, + instructions: str | None = None, + fnc_ctx: FunctionContext | None = None, + llm: LLM | None = None, + stt: STT | None = None, + name: str | None = None, + ) -> None: + self._chat_ctx = ChatContext() + if instructions: + self._chat_ctx.append(text=instructions, role="system") + self._fnc_ctx = fnc_ctx + self._llm = llm + # TODO: support customized llm and stt + self._stt = stt + + self._task_name = name or self.__class__.__name__ + + # enter method for transition + enter_fnc = self.enter + if not hasattr(enter_fnc, METADATA_ATTR): + enter_fnc = ai_callable( + name=f"enter_{self._task_name}", description=USE_DOCSTRING + )(self.enter) + + self._enter_fnc_info = FunctionContext._callable_to_fnc_info(enter_fnc) + + def can_enter(self) -> bool: + return True + + def enter(self) -> "AgentTask" | tuple["AgentTask", str]: + return self + + @property + def task_name(self) -> str: + return self._task_name + + @property + def enter_fnc_info(self) -> FunctionInfo: + return self._enter_fnc_info + + @property + def chat_ctx(self) -> ChatContext: + return self._chat_ctx + + @property + def fnc_ctx(self) -> FunctionContext | None: + return self._fnc_ctx + + @property + def llm(self) -> LLM | None: + return self._llm + + @property + def stt(self) -> STT | None: + return self._stt diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 64ad68757..7ac3efc31 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -19,10 +19,18 @@ from livekit import rtc from .. 
import metrics, stt, tokenize, tts, utils, vad -from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream +from ..llm import ( + LLM, + CalledFunction, + ChatContext, + ChatMessage, + FunctionContext, + LLMStream, +) from ..types import ATTRIBUTE_AGENT_STATE, AgentState from .agent_output import AgentOutput, SpeechSource, SynthesisHandle from .agent_playout import AgentPlayout +from .agent_task import AgentTask from .human_input import HumanInput from .log import logger from .plotter import AssistantPlotter @@ -184,8 +192,10 @@ def __init__( llm: LLM, tts: tts.TTS, turn_detector: _TurnDetector | None = None, - chat_ctx: ChatContext | None = None, - fnc_ctx: FunctionContext | None = None, + # chat_ctx: ChatContext | None = None, + # fnc_ctx: FunctionContext | None = None, + initial_task: AgentTask | None = None, + available_tasks: list[AgentTask] | None = None, allow_interruptions: bool = True, interrupt_speech_duration: float = 0.5, interrupt_min_words: int = 0, @@ -275,8 +285,8 @@ def __init__( self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts self._turn_detector = turn_detector - self._chat_ctx = chat_ctx or ChatContext() - self._fnc_ctx = fnc_ctx + # self._chat_ctx = chat_ctx or ChatContext() + # self._fnc_ctx = fnc_ctx self._started, self._closed = False, False self._human_input: HumanInput | None = None @@ -306,13 +316,51 @@ def __init__( self._last_final_transcript_time: float | None = None self._last_speech_time: float | None = None + # agent tasks + self._user_data: dict[str, Any] = {} + self._current_agent_task = initial_task or AgentTask() + self._agent_tasks: list[AgentTask] = available_tasks or [] + + @property + def user_data(self) -> dict[str, Any]: + return self._user_data + + @property + def agent_tasks(self) -> list[AgentTask]: + return self._agent_tasks + + @property + def current_agent_task(self) -> AgentTask: + return self._current_agent_task + + @current_agent_task.setter + def current_agent_task(self, task: AgentTask | None) -> None: + self._current_agent_task = task + @property def fnc_ctx(self) -> FunctionContext | None: - return self._fnc_ctx + available_tasks = [task for task in self._agent_tasks if task.can_enter()] + if not available_tasks: + # no transition available, return the current function context + return self._current_agent_task.fnc_ctx + + new_fnc_ctx = ( + self._current_agent_task.fnc_ctx.copy() + if self._current_agent_task.fnc_ctx + else FunctionContext() + ) + for task in available_tasks: + if task.enter_fnc_info.name in new_fnc_ctx._fncs: + raise ValueError( + f"duplicate ai_callable name: {task.enter_fnc_info.name}" + ) + new_fnc_ctx._fncs[task.enter_fnc_info.name] = task.enter_fnc_info + + return new_fnc_ctx - @fnc_ctx.setter - def fnc_ctx(self, fnc_ctx: FunctionContext | None) -> None: - self._fnc_ctx = fnc_ctx + @property + def _chat_ctx(self) -> ChatContext: + return self._current_agent_task.chat_ctx @property def chat_ctx(self) -> ChatContext: @@ -320,7 +368,7 @@ def chat_ctx(self) -> ChatContext: @property def llm(self) -> LLM: - return self._llm + return self._current_agent_task.llm or self._llm @property def tts(self) -> tts.TTS: @@ -383,6 +431,10 @@ def _on_llm_metrics(llm_metrics: metrics.LLMMetrics) -> None: ), ) + for agent_task in self._agent_tasks: + if agent_task.llm: + agent_task.llm.on("metrics_collected", _on_llm_metrics) + @self._vad.on("metrics_collected") def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None: self.emit( @@ -631,10 +683,7 @@ async def _main_task(self) -> None: agent_playout = 
AgentPlayout(audio_source=audio_source) self._agent_output = AgentOutput( - room=self._room, - agent_playout=agent_playout, - llm=self._llm, - tts=self._tts, + room=self._room, agent_playout=agent_playout, tts=self._tts ) def _on_playout_started() -> None: @@ -868,7 +917,7 @@ async def _execute_function_calls() -> None: self.emit("function_calls_collected", new_function_calls) - called_fncs = [] + called_fncs: list[CalledFunction] = [] for fnc in new_function_calls: called_fnc = fnc.execute() called_fncs.append(called_fnc) @@ -893,12 +942,25 @@ async def _execute_function_calls() -> None: tool_calls_info = [] tool_calls_results = [] - + tool_calls_chat_ctx = call_ctx.chat_ctx for called_fnc in called_fncs: # ignore the function calls that returns None if called_fnc.result is None and called_fnc.exception is None: continue + new_task = called_fnc.get_agent_task() + if new_task: + logger.debug( + "switching to next agent task", + extra={ + "current_task": self.current_agent_task.task_name, + "new_task": new_task.task_name, + }, + ) + self.current_agent_task = new_task + # use the new chat ctx for the next task + tool_calls_chat_ctx = self._chat_ctx + tool_calls_info.append(called_fnc.call_info) tool_calls_results.append( ChatMessage.create_tool_from_called_function(called_fnc) @@ -921,10 +983,10 @@ async def _execute_function_calls() -> None: ) # synthesize the tool speech with the chat ctx from llm_stream - chat_ctx = call_ctx.chat_ctx.copy() + chat_ctx = tool_calls_chat_ctx.copy() chat_ctx.messages.extend(extra_tools_messages) chat_ctx.messages.extend(call_ctx.extra_chat_messages) - answer_llm_stream = self._llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) + answer_llm_stream = self.llm.chat(chat_ctx=chat_ctx, fnc_ctx=self.fnc_ctx) synthesis_handle = self._synthesize_agent_speech( new_speech_handle.id, answer_llm_stream From fac0f53c81def448c9d9d97f9ffeafcaad299f28 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sat, 21 Dec 2024 18:31:58 +0800 Subject: [PATCH 06/13] add a new example for agent task --- .../multi_stage_agent2.py | 252 ++++++++++++++++++ .../livekit/agents/llm/function_context.py | 8 +- .../livekit/agents/pipeline/agent_task.py | 84 +++--- .../livekit/agents/pipeline/pipeline_agent.py | 8 +- 4 files changed, 292 insertions(+), 60 deletions(-) create mode 100644 examples/voice-pipeline-agent/multi_stage_agent2.py diff --git a/examples/voice-pipeline-agent/multi_stage_agent2.py b/examples/voice-pipeline-agent/multi_stage_agent2.py new file mode 100644 index 000000000..a503f6d6a --- /dev/null +++ b/examples/voice-pipeline-agent/multi_stage_agent2.py @@ -0,0 +1,252 @@ +import json +import logging +from typing import Annotated, Self + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.pipeline.agent_task import AgentTask +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("multi-stage-agent") +logger.setLevel(logging.INFO) + + +def get_last_n_messages( + messages: list[llm.ChatMessage], n: int +) -> list[llm.ChatMessage]: + collected_messages = messages.copy()[-n:] + while collected_messages and collected_messages[0].role in ["system", "tool"]: + collected_messages.pop(0) + return collected_messages + + +def _transfer_to(task: AgentTask, message: str | 
None = None) -> tuple[AgentTask, str]: + agent = AgentCallContext.get_current().agent + + # keep the last n messages for the next stage + keep_last_n = 6 + task.chat_ctx.messages.extend( + get_last_n_messages(agent.chat_ctx.messages, keep_last_n) + ) + + message = ( + message or f"Transferred from {agent.current_agent_task.name} to {task.name}" + ) + logger.info(message) + return task, message + + +class Greeter(AgentTask): + def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): + super().__init__( + instructions=( + "You are a professional restaurant receptionist handling incoming calls. " + "Warmly greet the caller and ask if they would like to place an order. " + f"Available menu items: {menu}. " + "Guide the conversation as follows:\n" + "- If they want to place an order, transfer them to order taking\n" + "- If they have completed their order, transfer them to customer details\n" + "- For any other inquiries, assist them directly\n" + "Maintain a friendly and professional tone throughout the conversation." + "Use the functions to transfer the call to the next step." + ), + ) + + def can_enter(self, agent: "VoicePipelineAgent") -> bool: + return True + + @llm.ai_callable(name="enter_greeter") + async def enter(self) -> tuple[Self, str]: + """Called to transfer to the greeter.""" + + agent = AgentCallContext.get_current().agent + curr_task = agent.current_agent_task + + # return the collected information to the greeter + message = f"Transferred from {curr_task.name} to {self.name}. " + if isinstance(curr_task, OrderTaking): + message += f"The current order is {agent.user_data.get('order', 'empty')}" + elif isinstance(curr_task, CustomerDetails): + message += ( + f"The customer name is {agent.user_data.get('customer_name', 'unknown')}, " + f"phone number is {agent.user_data.get('customer_phone', 'unknown')}" + ) + + return _transfer_to(self, message) + + +class OrderTaking(AgentTask): + def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): + super().__init__( + instructions=( + "You are a professional order taker at a restaurant. " + "Guide the customer through their order with these steps:\n" + f"1. Take their order selections one at a time from our menu: {menu}\n" + "2. Clarify any special requests or modifications\n" + "3. Repeat back the complete order to confirm accuracy\n" + "4. Once confirmed, transfer them back to the greeter\n" + "Be attentive and ensure order accuracy before proceeding." + ), + functions=[self.update_order], + ) + + def can_enter(self, agent: "VoicePipelineAgent") -> bool: + checked_out = agent.user_data.get("checked_out", False) + return not checked_out + + @llm.ai_callable(name="enter_order_taking") + async def enter(self) -> tuple[Self, str]: + """Called to transfer to the order taking.""" + return _transfer_to(self) + + @llm.ai_callable() + async def update_order( + self, + item: Annotated[ + str, + llm.TypeInfo( + description="The items of the full order, separated by commas" + ), + ], + ) -> str: + """Called when the user updates their order.""" + + agent = AgentCallContext.get_current().agent + agent.user_data["order"] = item + + logger.info("Updated order", extra={"order": item}) + return f"Updated order to {item}" + + +class CustomerDetails(AgentTask): + def __init__(self): + super().__init__( + instructions=( + "You are collecting essential customer information for their order. " + "Follow these steps carefully:\n" + "1. Ask for the customer's name and confirm the spelling\n" + "2. 
Request their phone number and verify it's correct\n" + "3. Repeat both pieces of information back to ensure accuracy\n" + "4. Once confirmed, transfer back to the greeter\n" + "Handle personal information professionally and courteously." + ), + functions=[self.collect_name, self.collect_phone], + ) + + def can_enter(self, agent: "VoicePipelineAgent") -> bool: + checked_out = agent.user_data.get("checked_out", False) + order = agent.user_data.get("order", None) + return order and not checked_out + + @llm.ai_callable(name="enter_customer_details") + async def enter(self) -> tuple[Self, str]: + """Called to transfer to the customer details.""" + return _transfer_to(self) + + @llm.ai_callable() + async def collect_name( + self, name: Annotated[str, llm.TypeInfo(description="The customer's name")] + ) -> str: + """Called when the user provides their name.""" + agent = AgentCallContext.get_current().agent + agent.user_data["customer_name"] = name + + logger.info("Collected name", extra={"customer_name": name}) + return f"The name is updated to {name}" + + @llm.ai_callable() + async def collect_phone( + self, + phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], + ) -> str: + """Called when the user provides their phone number.""" + agent = AgentCallContext.get_current().agent + agent.user_data["customer_phone"] = phone + + logger.info("Collected phone", extra={"customer_phone": phone}) + return f"The phone number is updated to {phone}" + + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + chat_log_file = "multi_stage_chat_log.txt" + menu = "Pizza, Salad, Ice Cream, Coffee" + agent_tasks = [Greeter(menu), OrderTaking(menu), CustomerDetails()] + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=openai.TTS(), + initial_task=agent_tasks[0], + available_tasks=agent_tasks, + max_nested_fnc_calls=2, # may call functions in the transition function + ) + + # For testing with text input + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.debug("Text input received", extra={"message": data["message"]}) + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + with open(chat_log_file, "a") as f: + f.write(f"{message.role}: {message.content}\n") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + with open(chat_log_file, "a") as f: + f.write(f"fnc_calls_collected: {fnc_infos}\n") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + with open(chat_log_file, "a") as f: + f.write(f"fnc_calls_finished: {called_infos}\n") + + # Start the assistant. 
This will automatically publish a microphone track and listen to the participant. + agent.start(ctx.room, participant) + await agent.say( + f"Welcome to our restaurant! We offer {menu}. How may I assist you today?" + ) + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index cf715d1df..d523b7900 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -105,6 +105,8 @@ class CalledFunction: exception: BaseException | None = None def get_agent_task(self) -> "AgentTask" | None: + from ..pipeline.agent_task import AgentTask + assert self.task.done() if isinstance(self.result, tuple): @@ -115,6 +117,8 @@ def get_agent_task(self) -> "AgentTask" | None: return None def get_content(self) -> Any | None: + from ..pipeline.agent_task import AgentTask + assert self.task.done() if self.exception: @@ -130,7 +134,7 @@ def get_content(self) -> Any | None: def ai_callable( *, name: str | None = None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker | None = USE_DOCSTRING, auto_retry: bool = False, ) -> Callable: def deco(f): @@ -152,7 +156,7 @@ def ai_callable( self, *, name: str | None = None, - description: str | _UseDocMarker | None = None, + description: str | _UseDocMarker | None = USE_DOCSTRING, auto_retry: bool = True, ) -> Callable: def deco(f): diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 10818a81a..35209d139 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -1,89 +1,65 @@ +from typing import TYPE_CHECKING, Callable, Optional, Union + from ..llm import LLM, ChatContext, FunctionContext -from ..llm.function_context import ( - METADATA_ATTR, - USE_DOCSTRING, - FunctionInfo, - ai_callable, -) +from ..llm.function_context import USE_DOCSTRING, FunctionInfo, ai_callable from ..stt import STT -# class TaskContext: -# def __init__(self, assistant: "VoicePipelineAgent"): -# self._assistant = assistant - -# @property -# def agent(self) -> "VoicePipelineAgent": -# return self._assistant - -# @property -# def user_data(self) -> dict[str, Any]: -# return self._assistant.user_data - -# @property -# def current_task(self) -> "AgentTask" | None: -# return self._assistant._current_task - -# @property -# def room(self) -> rtc.Room: -# if not hasattr(self._assistant, "_room"): -# raise ValueError("VoicePipelineAgent is not started") -# return self._assistant._room +if TYPE_CHECKING: + from ..pipeline import VoicePipelineAgent class AgentTask: def __init__( self, - instructions: str | None = None, - fnc_ctx: FunctionContext | None = None, - llm: LLM | None = None, - stt: STT | None = None, - name: str | None = None, + instructions: Optional[str] = None, + functions: Optional[list[Callable]] = None, + llm: Optional[LLM] = None, + name: Optional[str] = None, ) -> None: self._chat_ctx = ChatContext() if instructions: self._chat_ctx.append(text=instructions, role="system") - self._fnc_ctx = fnc_ctx - self._llm = llm - # TODO: support customized llm and stt - self._stt = stt - - self._task_name = name or self.__class__.__name__ + self._fnc_ctx: Optional[FunctionContext] = None + if functions: + self._fnc_ctx = FunctionContext() + for fnc in functions: + 
self._fnc_ctx._register_ai_function(fnc) - # enter method for transition - enter_fnc = self.enter - if not hasattr(enter_fnc, METADATA_ATTR): - enter_fnc = ai_callable( - name=f"enter_{self._task_name}", description=USE_DOCSTRING - )(self.enter) - - self._enter_fnc_info = FunctionContext._callable_to_fnc_info(enter_fnc) + self._llm = llm + self._stt = None + self._name = name or self.__class__.__name__ + self._enter_fnc_info = FunctionContext._callable_to_fnc_info(self.enter) + if not self._enter_fnc_info: + raise ValueError("enter function must be decorated with ai_callable") - def can_enter(self) -> bool: + def can_enter(self, agent: "VoicePipelineAgent") -> bool: return True - def enter(self) -> "AgentTask" | tuple["AgentTask", str]: + @ai_callable(name="enter_task", description=USE_DOCSTRING) + async def enter(self) -> Union["AgentTask", tuple["AgentTask", str]]: + """Called to enter the task.""" return self @property - def task_name(self) -> str: - return self._task_name + def name(self) -> str: + return self._name @property def enter_fnc_info(self) -> FunctionInfo: - return self._enter_fnc_info + return self._enter_fnc_info # type: ignore @property def chat_ctx(self) -> ChatContext: return self._chat_ctx @property - def fnc_ctx(self) -> FunctionContext | None: + def fnc_ctx(self) -> Optional[FunctionContext]: return self._fnc_ctx @property - def llm(self) -> LLM | None: + def llm(self) -> Optional[LLM]: return self._llm @property - def stt(self) -> STT | None: + def stt(self) -> Optional[STT]: return self._stt diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 7ac3efc31..ad779819f 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -334,12 +334,12 @@ def current_agent_task(self) -> AgentTask: return self._current_agent_task @current_agent_task.setter - def current_agent_task(self, task: AgentTask | None) -> None: + def current_agent_task(self, task: AgentTask) -> None: self._current_agent_task = task @property def fnc_ctx(self) -> FunctionContext | None: - available_tasks = [task for task in self._agent_tasks if task.can_enter()] + available_tasks = [task for task in self._agent_tasks if task.can_enter(self)] if not available_tasks: # no transition available, return the current function context return self._current_agent_task.fnc_ctx @@ -953,8 +953,8 @@ async def _execute_function_calls() -> None: logger.debug( "switching to next agent task", extra={ - "current_task": self.current_agent_task.task_name, - "new_task": new_task.task_name, + "current_task": self.current_agent_task.name, + "new_task": new_task.name, }, ) self.current_agent_task = new_task From 5d4910074d4cdaff5714331f59fc9ae348c0e538 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sat, 21 Dec 2024 18:56:01 +0800 Subject: [PATCH 07/13] add checkout task --- .../multi_stage_agent2.py | 66 +++++++++++++++---- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_stage_agent2.py b/examples/voice-pipeline-agent/multi_stage_agent2.py index a503f6d6a..548bd95a4 100644 --- a/examples/voice-pipeline-agent/multi_stage_agent2.py +++ b/examples/voice-pipeline-agent/multi_stage_agent2.py @@ -49,18 +49,16 @@ def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask class Greeter(AgentTask): - def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): + def __init__(self, menu: str): 
super().__init__( instructions=( "You are a professional restaurant receptionist handling incoming calls. " "Warmly greet the caller and ask if they would like to place an order. " f"Available menu items: {menu}. " - "Guide the conversation as follows:\n" - "- If they want to place an order, transfer them to order taking\n" - "- If they have completed their order, transfer them to customer details\n" - "- For any other inquiries, assist them directly\n" - "Maintain a friendly and professional tone throughout the conversation." - "Use the functions to transfer the call to the next step." + "Maintain a friendly and professional tone throughout the conversation.\n" + "Guide the conversation as follows: order taking, customer details, checkout. " + "Use the functions to transfer the call to OrderTaking, CustomerDetails, or Checkout. " + "For any other inquiries, assist them directly." ), ) @@ -88,7 +86,7 @@ async def enter(self) -> tuple[Self, str]: class OrderTaking(AgentTask): - def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): + def __init__(self, menu: str): super().__init__( instructions=( "You are a professional order taker at a restaurant. " @@ -96,8 +94,9 @@ def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): f"1. Take their order selections one at a time from our menu: {menu}\n" "2. Clarify any special requests or modifications\n" "3. Repeat back the complete order to confirm accuracy\n" - "4. Once confirmed, transfer them back to the greeter\n" + "4. Once confirmed, transfer them to collect customer details.\n" "Be attentive and ensure order accuracy before proceeding." + "Use the functions to transfer the call to the next step." ), functions=[self.update_order], ) @@ -139,8 +138,9 @@ def __init__(self): "1. Ask for the customer's name and confirm the spelling\n" "2. Request their phone number and verify it's correct\n" "3. Repeat both pieces of information back to ensure accuracy\n" - "4. Once confirmed, transfer back to the greeter\n" + "4. Once confirmed, transfer to checkout.\n" "Handle personal information professionally and courteously." + "Use the functions to transfer the call to the next step." ), functions=[self.collect_name, self.collect_phone], ) @@ -179,6 +179,48 @@ async def collect_phone( return f"The phone number is updated to {phone}" +class Checkout(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + "You are a checkout agent. Ask the customer if they want to checkout. " + f"The menu items and prices are: {menu}. " + "If they confirm, call the checkout function and transfer them back to the greeter." + ), + functions=[self.checkout], + ) + + def can_enter(self, agent: "VoicePipelineAgent") -> bool: + checked_out = agent.user_data.get("checked_out", False) + order = agent.user_data.get("order", None) + customer_name = agent.user_data.get("customer_name", None) + customer_phone = agent.user_data.get("customer_phone", None) + + return order and customer_name and customer_phone and not checked_out + + @llm.ai_callable(name="enter_checkout") + async def enter(self) -> tuple[Self, str]: + """Called to transfer to the checkout.""" + agent = AgentCallContext.get_current().agent + message = f"Transferred from {agent.current_agent_task.name} to {self.name}. " + message += f"The current order is {agent.user_data.get('order', 'empty')}. 
" + message += f"The customer name is {agent.user_data.get('customer_name', 'unknown')}, " + message += f"phone number is {agent.user_data.get('customer_phone', 'unknown')}" + return _transfer_to(self, message) + + @llm.ai_callable() + async def checkout( + self, + expense: Annotated[float, llm.TypeInfo(description="The expense of the order")], + ) -> str: + """Called when the user confirms the checkout.""" + agent = AgentCallContext.get_current().agent + agent.user_data["checked_out"] = True + agent.user_data["expense"] = expense + logger.info("Checked out", extra=agent.user_data) + return "Checked out" + + def prewarm_process(proc: JobProcess): # preload silero VAD in memory to speed up session start proc.userdata["vad"] = silero.VAD.load() @@ -188,8 +230,8 @@ async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) chat_log_file = "multi_stage_chat_log.txt" - menu = "Pizza, Salad, Ice Cream, Coffee" - agent_tasks = [Greeter(menu), OrderTaking(menu), CustomerDetails()] + menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" + agent_tasks = [Greeter(menu), OrderTaking(menu), CustomerDetails(), Checkout(menu)] participant = await ctx.wait_for_participant() agent = VoicePipelineAgent( From da9da7bc01663dd9f6ab0b14dd29a1f5468008c3 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sat, 21 Dec 2024 23:52:50 +0800 Subject: [PATCH 08/13] improve multi task example --- ...ti_stage_agent2.py => multi_task_agent.py} | 164 +++++++++++------- .../livekit/agents/pipeline/agent_task.py | 2 +- .../livekit/agents/pipeline/pipeline_agent.py | 17 +- 3 files changed, 106 insertions(+), 77 deletions(-) rename examples/voice-pipeline-agent/{multi_stage_agent2.py => multi_task_agent.py} (61%) diff --git a/examples/voice-pipeline-agent/multi_stage_agent2.py b/examples/voice-pipeline-agent/multi_task_agent.py similarity index 61% rename from examples/voice-pipeline-agent/multi_stage_agent2.py rename to examples/voice-pipeline-agent/multi_task_agent.py index 548bd95a4..86621beaa 100644 --- a/examples/voice-pipeline-agent/multi_stage_agent2.py +++ b/examples/voice-pipeline-agent/multi_task_agent.py @@ -32,6 +32,15 @@ def get_last_n_messages( return collected_messages +user_data_template = { + "order": [], + "customer_name": None, + "customer_phone": None, + "checked_out": False, + "expense": None, +} + + def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask, str]: agent = AgentCallContext.get_current().agent @@ -40,10 +49,14 @@ def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask task.chat_ctx.messages.extend( get_last_n_messages(agent.chat_ctx.messages, keep_last_n) ) + if not message: + user_data = user_data_template.copy() + user_data.update(agent.user_data) + message = ( + f"Transferred from {agent.current_agent_task.name} to {task.name}. " + f"The current user data is {json.dumps(user_data)}" + ) - message = ( - message or f"Transferred from {agent.current_agent_task.name} to {task.name}" - ) logger.info(message) return task, message @@ -52,51 +65,54 @@ class Greeter(AgentTask): def __init__(self, menu: str): super().__init__( instructions=( - "You are a professional restaurant receptionist handling incoming calls. " - "Warmly greet the caller and ask if they would like to place an order. " - f"Available menu items: {menu}. " - "Maintain a friendly and professional tone throughout the conversation.\n" - "Guide the conversation as follows: order taking, customer details, checkout. 
" - "Use the functions to transfer the call to OrderTaking, CustomerDetails, or Checkout. " - "For any other inquiries, assist them directly." + "You are a friendly restaurant receptionist. Your tasks:\n" + "1. Warmly greet the caller\n" + f"2. Ask if they'd like to place an order. (menu: {menu})\n" + "Transfer to:\n" + "- order_taking: when ready to place order\n" + "- customer_registration: only after order is complete\n" + "- checkout: only after customer details are collected\n\n" + "Important:\n" + "- If a transfer function is unavailable, it means prerequisites aren't met\n" + "- Guide the customer to complete previous steps first\n" + "- If already checked out, start a new order\n\n" + "For non-order inquiries, assist directly while maintaining a professional tone." ), + functions=[self.start_new_order], ) def can_enter(self, agent: "VoicePipelineAgent") -> bool: return True - @llm.ai_callable(name="enter_greeter") + @llm.ai_callable(name="transfer_to_greeter") async def enter(self) -> tuple[Self, str]: """Called to transfer to the greeter.""" + return _transfer_to(self) + @llm.ai_callable() + async def start_new_order(self) -> str: + """Called to start a new order.""" agent = AgentCallContext.get_current().agent - curr_task = agent.current_agent_task - - # return the collected information to the greeter - message = f"Transferred from {curr_task.name} to {self.name}. " - if isinstance(curr_task, OrderTaking): - message += f"The current order is {agent.user_data.get('order', 'empty')}" - elif isinstance(curr_task, CustomerDetails): - message += ( - f"The customer name is {agent.user_data.get('customer_name', 'unknown')}, " - f"phone number is {agent.user_data.get('customer_phone', 'unknown')}" - ) - - return _transfer_to(self, message) + agent.user_data.clear() + logger.info("Started a new order") + return "Started a new order" class OrderTaking(AgentTask): def __init__(self, menu: str): super().__init__( instructions=( - "You are a professional order taker at a restaurant. " - "Guide the customer through their order with these steps:\n" - f"1. Take their order selections one at a time from our menu: {menu}\n" - "2. Clarify any special requests or modifications\n" - "3. Repeat back the complete order to confirm accuracy\n" - "4. Once confirmed, transfer them to collect customer details.\n" - "Be attentive and ensure order accuracy before proceeding." - "Use the functions to transfer the call to the next step." + "You are a professional order taker at a restaurant. Your tasks:\n" + f"1. Take orders from our menu: {menu}\n" + "2. Clarify special requests\n" + "3. 
Confirm order accuracy\n\n" + "Transfer to:\n" + "- customer_registration: when order is confirmed\n" + "- greeter: for general questions or starting over\n\n" + "Important:\n" + "- Use update_order function to save the order\n" + "- Ensure order is complete before transferring to customer details\n" + "- For non-order questions, transfer to greeter" ), functions=[self.update_order], ) @@ -105,7 +121,7 @@ def can_enter(self, agent: "VoicePipelineAgent") -> bool: checked_out = agent.user_data.get("checked_out", False) return not checked_out - @llm.ai_callable(name="enter_order_taking") + @llm.ai_callable(name="transfer_to_order_taking") async def enter(self) -> tuple[Self, str]: """Called to transfer to the order taking.""" return _transfer_to(self) @@ -113,34 +129,36 @@ async def enter(self) -> tuple[Self, str]: @llm.ai_callable() async def update_order( self, - item: Annotated[ - str, - llm.TypeInfo( - description="The items of the full order, separated by commas" - ), + items: Annotated[ + list[str], + llm.TypeInfo(description="The items of the full order"), ], ) -> str: """Called when the user updates their order.""" agent = AgentCallContext.get_current().agent - agent.user_data["order"] = item + agent.user_data["order"] = items - logger.info("Updated order", extra={"order": item}) - return f"Updated order to {item}" + logger.info("Updated order", extra={"order": items}) + return f"Updated order to {items}" -class CustomerDetails(AgentTask): +class CustomerRegistration(AgentTask): def __init__(self): super().__init__( instructions=( - "You are collecting essential customer information for their order. " - "Follow these steps carefully:\n" - "1. Ask for the customer's name and confirm the spelling\n" - "2. Request their phone number and verify it's correct\n" + "You are collecting customer information for their order. Your tasks:\n" + "1. Get and confirm customer's name and comfirm the spelling\n" + "2. Get phone number and verify it's correct\n" "3. Repeat both pieces of information back to ensure accuracy\n" - "4. Once confirmed, transfer to checkout.\n" - "Handle personal information professionally and courteously." - "Use the functions to transfer the call to the next step." + "Transfer to:\n" + "- checkout: when all details are confirmed\n" + "- order_taking: to modify the order\n" + "- greeter: for general questions\n\n" + "Important:\n" + "- Use collect_name and collect_phone functions to save details\n" + "- Verify all information before proceeding to checkout\n" + "- For non-detail questions, transfer to greeter" ), functions=[self.collect_name, self.collect_phone], ) @@ -150,9 +168,9 @@ def can_enter(self, agent: "VoicePipelineAgent") -> bool: order = agent.user_data.get("order", None) return order and not checked_out - @llm.ai_callable(name="enter_customer_details") + @llm.ai_callable(name="transfer_to_customer_registration") async def enter(self) -> tuple[Self, str]: - """Called to transfer to the customer details.""" + """Called to transfer to the customer registration.""" return _transfer_to(self) @llm.ai_callable() @@ -183,29 +201,38 @@ class Checkout(AgentTask): def __init__(self, menu: str): super().__init__( instructions=( - "You are a checkout agent. Ask the customer if they want to checkout. " - f"The menu items and prices are: {menu}. " - "If they confirm, call the checkout function and transfer them back to the greeter." + "You are a checkout agent at a restaurant. Your tasks:\n" + f"1. Review order and prices ({menu})\n" + "2. Calculate and confirm total\n" + "3. 
Process checkout\n\n" + "Transfer to:\n" + "- order_taking: to modify order\n" + "- customer_registration: to update information\n" + "- greeter: after checkout or for general questions\n\n" + "Important:\n" + "- Use checkout function with final expense\n" + "- After successful checkout, transfer to greeter\n" + "- For non-checkout questions, transfer to greeter" ), functions=[self.checkout], ) def can_enter(self, agent: "VoicePipelineAgent") -> bool: checked_out = agent.user_data.get("checked_out", False) - order = agent.user_data.get("order", None) - customer_name = agent.user_data.get("customer_name", None) - customer_phone = agent.user_data.get("customer_phone", None) + order = agent.user_data.get("order") + customer_name = agent.user_data.get("customer_name") + customer_phone = agent.user_data.get("customer_phone") return order and customer_name and customer_phone and not checked_out - @llm.ai_callable(name="enter_checkout") + @llm.ai_callable(name="transfer_to_checkout") async def enter(self) -> tuple[Self, str]: """Called to transfer to the checkout.""" agent = AgentCallContext.get_current().agent message = f"Transferred from {agent.current_agent_task.name} to {self.name}. " - message += f"The current order is {agent.user_data.get('order', 'empty')}. " - message += f"The customer name is {agent.user_data.get('customer_name', 'unknown')}, " - message += f"phone number is {agent.user_data.get('customer_phone', 'unknown')}" + message += f"The current order is {agent.user_data.get('order')}. " + message += f"The customer name is {agent.user_data.get('customer_name')}, " + message += f"phone number is {agent.user_data.get('customer_phone')}" return _transfer_to(self, message) @llm.ai_callable() @@ -229,9 +256,14 @@ def prewarm_process(proc: JobProcess): async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - chat_log_file = "multi_stage_chat_log.txt" + chat_log_file = "multi_task_chat_log.txt" menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" - agent_tasks = [Greeter(menu), OrderTaking(menu), CustomerDetails(), Checkout(menu)] + agent_tasks = [ + Greeter(menu), + OrderTaking(menu), + CustomerRegistration(), + Checkout(menu), + ] participant = await ctx.wait_for_participant() agent = VoicePipelineAgent( @@ -241,7 +273,7 @@ async def entrypoint(ctx: JobContext): tts=openai.TTS(), initial_task=agent_tasks[0], available_tasks=agent_tasks, - max_nested_fnc_calls=2, # may call functions in the transition function + max_nested_fnc_calls=3, # may call functions in the transition function ) # For testing with text input @@ -280,9 +312,7 @@ def on_function_calls_finished(calls: list[llm.CalledFunction]): # Start the assistant. This will automatically publish a microphone track and listen to the participant. agent.start(ctx.room, participant) - await agent.say( - f"Welcome to our restaurant! We offer {menu}. How may I assist you today?" - ) + await agent.say("Welcome to our restaurant! 
How may I assist you today?") if __name__ == "__main__": diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 35209d139..c851c8e30 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -35,7 +35,7 @@ def __init__( def can_enter(self, agent: "VoicePipelineAgent") -> bool: return True - @ai_callable(name="enter_task", description=USE_DOCSTRING) + @ai_callable(name="transfer_to_task", description=USE_DOCSTRING) async def enter(self) -> Union["AgentTask", tuple["AgentTask", str]]: """Called to enter the task.""" return self diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index ad779819f..d87adf4b3 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -888,6 +888,13 @@ async def _execute_function_calls() -> None: if not is_using_tools or interrupted: return + assert isinstance(speech_handle.source, LLMStream) + assert ( + not user_question or speech_handle.user_committed + ), "user speech should have been committed before using tools" + + llm_stream = speech_handle.source + if speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls: logger.warning( "max function calls nested depth reached", @@ -895,20 +902,12 @@ async def _execute_function_calls() -> None: "speech_id": speech_handle.id, "fnc_nested_depth": speech_handle.fnc_nested_depth, "fnc_names": [ - fnc.function_info.name - for fnc in speech_handle.source.function_calls + fnc.function_info.name for fnc in llm_stream.function_calls ], }, ) return - assert isinstance(speech_handle.source, LLMStream) - assert ( - not user_question or speech_handle.user_committed - ), "user speech should have been committed before using tools" - - llm_stream = speech_handle.source - # execute functions call_ctx = AgentCallContext(self, llm_stream) tk = _CallContextVar.set(call_ctx) From 3e7cede22c38a4486ed079f0cf1394efcf7c3ddf Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 22 Dec 2024 13:15:30 +0800 Subject: [PATCH 09/13] clean example --- .../voice-pipeline-agent/multi_stage_agent.py | 282 ------------------ .../voice-pipeline-agent/multi_task_agent.py | 42 +-- 2 files changed, 23 insertions(+), 301 deletions(-) delete mode 100644 examples/voice-pipeline-agent/multi_stage_agent.py diff --git a/examples/voice-pipeline-agent/multi_stage_agent.py b/examples/voice-pipeline-agent/multi_stage_agent.py deleted file mode 100644 index b52a2fc86..000000000 --- a/examples/voice-pipeline-agent/multi_stage_agent.py +++ /dev/null @@ -1,282 +0,0 @@ -import json -import logging -from dataclasses import dataclass -from typing import Annotated, Callable - -from dotenv import load_dotenv -from livekit import rtc -from livekit.agents import ( - AutoSubscribe, - JobContext, - JobProcess, - WorkerOptions, - cli, - llm, -) -from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent -from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType -from livekit.plugins import deepgram, openai, silero - -load_dotenv() - -logger = logging.getLogger("multi-stage-agent") -logger.setLevel(logging.INFO) - - -@dataclass -class AgentSpec: - chat_ctx: llm.ChatContext - fnc_ctx: llm.FunctionContext - - @classmethod - def create(cls, instructions: str, fncs: list[Callable]): - chat_ctx = llm.ChatContext().append(text=instructions, role="system") - 
fnc_ctx = llm.FunctionContext() - for fnc in fncs: - fnc_ctx._register_ai_function(fnc) - return cls(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) - - -class RestaurantBot: - def __init__(self, menu: str = "Pizza, Salad, Ice Cream, Coffee"): - self._menu = menu - self._specs = { - "Greeter": AgentSpec.create( - instructions=( - "You are a professional restaurant receptionist handling incoming calls. " - "Warmly greet the caller and ask if they would like to place an order. " - f"Available menu items: {self._menu}. " - "Guide the conversation as follows:\n" - "- If they want to place an order, transfer them to order taking\n" - "- If they have completed their order, transfer them to customer details\n" - "- For any other inquiries, assist them directly\n" - "Maintain a friendly and professional tone throughout the conversation." - "Use the functions to transfer the call to the next step." - ), - fncs=[ - self.transfer_to_ordering, - self.transfer_to_info_collection, - ], - ), - "OrderTaking": AgentSpec.create( - instructions=( - "You are a professional order taker at a restaurant. " - "Guide the customer through their order with these steps:\n" - f"1. Take their order selections one at a time from our menu: {self._menu}\n" - "2. Clarify any special requests or modifications\n" - "3. Repeat back the complete order to confirm accuracy\n" - "4. Once confirmed, transfer them back to the greeter\n" - "Be attentive and ensure order accuracy before proceeding." - ), - fncs=[ - self.update_order, - self.transfer_to_greeter, - ], - ), - "CustomerDetails": AgentSpec.create( - instructions=( - "You are collecting essential customer information for their order. " - "Follow these steps carefully:\n" - "1. Ask for the customer's name and confirm the spelling\n" - "2. Request their phone number and verify it's correct\n" - "3. Repeat both pieces of information back to ensure accuracy\n" - "4. Once confirmed, transfer back to the greeter\n" - "Handle personal information professionally and courteously." 
- ), - fncs=[ - self.collect_name, - self.collect_phone, - self.transfer_to_greeter, - ], - ), - } - - self._cur_spec = "Greeter" - self._order: str | None = None - self._customer_name: str | None = None - self._customer_phone: str | None = None - - @property - def spec(self) -> AgentSpec: - return self._specs[self._cur_spec] - - def _transfer_to_spec(self, spec_name: str, call_ctx: AgentCallContext) -> None: - agent = call_ctx.agent - - # keep the last n messages for the next stage - keep_last_n = 6 - prev_messages = agent.chat_ctx.messages.copy()[-keep_last_n:] - while prev_messages and prev_messages[0].role in ["system", "tool"]: - prev_messages.pop(0) - - self._cur_spec = spec_name - agent._fnc_ctx = self.spec.fnc_ctx - agent._chat_ctx = self.spec.chat_ctx - agent._chat_ctx.messages.extend(prev_messages) - - # use the new chat_ctx in the call_ctx - call_ctx.chat_ctx.messages = agent.chat_ctx.messages.copy() - logger.info(f"Transferred to {spec_name}") - - @llm.ai_callable() - async def update_order( - self, - item: Annotated[ - str, - llm.TypeInfo( - description="The items of the full order, separated by commas" - ), - ], - ): - """Called when the user updates their order.""" - self._order = item - logger.info("Updated order", extra={"order": item}) - return f"Updated order to {item}" - - @llm.ai_callable() - async def collect_name( - self, name: Annotated[str, llm.TypeInfo(description="The customer's name")] - ): - """Called when the user provides their name.""" - self._customer_name = name - logger.info("Collected name", extra={"customer_name": name}) - return f"The name is updated to {name}" - - @llm.ai_callable() - async def collect_phone( - self, - phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], - ): - """Called when the user provides their phone number.""" - # validate phone number (optional) - phone = phone.strip().replace("-", "") - if not phone.isdigit() or len(phone) != 10: - return "The phone number is invalid, it should be a 10-digit number, please try again." - - self._customer_phone = phone - logger.info("Collected phone", extra={"customer_phone": phone}) - return f"The phone number is updated to {phone}" - - @llm.ai_callable() - async def transfer_to_ordering(self): - """Called to transfer the call to order taking.""" - call_ctx = AgentCallContext.get_current() - - self._transfer_to_spec("OrderTaking", call_ctx) - await call_ctx.agent.say( - "I'll transfer you to our order taker who will help you with your selections." - ) - - return f"Transferred to order taking, the current order is {self._order}" - - @llm.ai_callable() - async def transfer_to_info_collection(self): - """Called to transfer the call to collect the customer's details.""" - call_ctx = AgentCallContext.get_current() - - self._transfer_to_spec("CustomerDetails", call_ctx) - await call_ctx.agent.say("Great! I'll collect your contact information now.") - return ( - f"Transferred to collecting customer details, " - f"the current collected name is {self._customer_name} " - f"and phone number is {self._customer_phone}" - ) - - @llm.ai_callable() - async def transfer_to_greeter( - self, - # summary: Annotated[ - # str, - # llm.TypeInfo( - # description="The summary of conversations in the current stage" - # ), - # ], - ): - """Called to transfer the call back to the greeter.""" - # message = f"Back to the greeter from {self._cur_spec}, the summary of conversations is {summary}" - message = f"Back to the greeter from {self._cur_spec}. 
" - if self._cur_spec == "OrderTaking": - message += f"The current order is {self._order}" - elif self._cur_spec == "CustomerDetails": - message += ( - f"The current collected name is {self._customer_name} " - f"and phone number is {self._customer_phone}" - ) - logger.info("Back to greeter", extra={"summary": message}) - - call_ctx = AgentCallContext.get_current() - self._transfer_to_spec("Greeter", call_ctx) - return message - - -def prewarm_process(proc: JobProcess): - # preload silero VAD in memory to speed up session start - proc.userdata["vad"] = silero.VAD.load() - - -async def entrypoint(ctx: JobContext): - await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - - chat_log_file = "multi_stage_chat_log.txt" - menu = "Pizza, Salad, Ice Cream, Coffee" - multi_stage_ctx = RestaurantBot(menu) - - participant = await ctx.wait_for_participant() - agent = VoicePipelineAgent( - vad=ctx.proc.userdata["vad"], - stt=deepgram.STT(), - llm=openai.LLM(), - tts=openai.TTS(), - fnc_ctx=multi_stage_ctx.spec.fnc_ctx, - chat_ctx=multi_stage_ctx.spec.chat_ctx, - max_nested_fnc_calls=2, # may call functions in the transition function - ) - - # For testing with text input - @ctx.room.on("data_received") - def on_data_received(packet: rtc.DataPacket): - if packet.topic == "lk-chat-topic": - data = json.loads(packet.data.decode("utf-8")) - logger.info(f"Text input received: {data['message']}") - - agent._human_input.emit( - "final_transcript", - SpeechEvent( - type=SpeechEventType.END_OF_SPEECH, - alternatives=[SpeechData(language="en", text=data["message"])], - ), - ) - - @agent.on("user_speech_committed") - @agent.on("agent_speech_interrupted") - @agent.on("agent_speech_committed") - def on_speech_committed(message: llm.ChatMessage): - with open(chat_log_file, "a") as f: - f.write(f"{message.role}: {message.content}\n") - - @agent.on("function_calls_collected") - def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): - fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] - with open(chat_log_file, "a") as f: - f.write(f"fnc_calls_collected: {fnc_infos}\n") - - @agent.on("function_calls_finished") - def on_function_calls_finished(calls: list[llm.CalledFunction]): - called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] - with open(chat_log_file, "a") as f: - f.write(f"fnc_calls_finished: {called_infos}\n") - - # Start the assistant. This will automatically publish a microphone track and listen to the participant. - agent.start(ctx.room, participant) - await agent.say( - f"Welcome to our restaurant! We offer {menu}. How may I assist you today?" 
- ) - - -if __name__ == "__main__": - cli.run_app( - WorkerOptions( - entrypoint_fnc=entrypoint, - prewarm_fnc=prewarm_process, - ), - ) diff --git a/examples/voice-pipeline-agent/multi_task_agent.py b/examples/voice-pipeline-agent/multi_task_agent.py index 86621beaa..274c2d3a8 100644 --- a/examples/voice-pipeline-agent/multi_task_agent.py +++ b/examples/voice-pipeline-agent/multi_task_agent.py @@ -19,7 +19,7 @@ load_dotenv() -logger = logging.getLogger("multi-stage-agent") +logger = logging.getLogger("multi-task-agent") logger.setLevel(logging.INFO) @@ -49,7 +49,9 @@ def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask task.chat_ctx.messages.extend( get_last_n_messages(agent.chat_ctx.messages, keep_last_n) ) + if not message: + # add the current user data to the message user_data = user_data_template.copy() user_data.update(agent.user_data) message = ( @@ -228,12 +230,7 @@ def can_enter(self, agent: "VoicePipelineAgent") -> bool: @llm.ai_callable(name="transfer_to_checkout") async def enter(self) -> tuple[Self, str]: """Called to transfer to the checkout.""" - agent = AgentCallContext.get_current().agent - message = f"Transferred from {agent.current_agent_task.name} to {self.name}. " - message += f"The current order is {agent.user_data.get('order')}. " - message += f"The customer name is {agent.user_data.get('customer_name')}, " - message += f"phone number is {agent.user_data.get('customer_phone')}" - return _transfer_to(self, message) + return _transfer_to(self) @llm.ai_callable() async def checkout( @@ -248,16 +245,20 @@ async def checkout( return "Checked out" -def prewarm_process(proc: JobProcess): - # preload silero VAD in memory to speed up session start - proc.userdata["vad"] = silero.VAD.load() - - async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) chat_log_file = "multi_task_chat_log.txt" menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" + + # Set up chat logger + chat_logger = logging.getLogger("chat_logger") + chat_logger.setLevel(logging.INFO) + handler = logging.FileHandler(chat_log_file) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + chat_logger.addHandler(handler) + agent_tasks = [ Greeter(menu), OrderTaking(menu), @@ -276,7 +277,7 @@ async def entrypoint(ctx: JobContext): max_nested_fnc_calls=3, # may call functions in the transition function ) - # For testing with text input + # read text input from the room for easy testing @ctx.room.on("data_received") def on_data_received(packet: rtc.DataPacket): if packet.topic == "lk-chat-topic": @@ -291,30 +292,33 @@ def on_data_received(packet: rtc.DataPacket): ), ) + # write the chat log to a file @agent.on("user_speech_committed") @agent.on("agent_speech_interrupted") @agent.on("agent_speech_committed") def on_speech_committed(message: llm.ChatMessage): - with open(chat_log_file, "a") as f: - f.write(f"{message.role}: {message.content}\n") + chat_logger.info(f"{message.role}: {message.content}") @agent.on("function_calls_collected") def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] - with open(chat_log_file, "a") as f: - f.write(f"fnc_calls_collected: {fnc_infos}\n") + chat_logger.info(f"fnc_calls_collected: {fnc_infos}") @agent.on("function_calls_finished") def on_function_calls_finished(calls: list[llm.CalledFunction]): called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] - with 
open(chat_log_file, "a") as f: - f.write(f"fnc_calls_finished: {called_infos}\n") + chat_logger.info(f"fnc_calls_finished: {called_infos}") # Start the assistant. This will automatically publish a microphone track and listen to the participant. agent.start(ctx.room, participant) await agent.say("Welcome to our restaurant! How may I assist you today?") +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + if __name__ == "__main__": cli.run_app( WorkerOptions( From 726ed329a07b1040e22549feaaa54a1429f8526a Mon Sep 17 00:00:00 2001 From: Long Chen Date: Sun, 22 Dec 2024 23:23:19 +0800 Subject: [PATCH 10/13] refactor the AgentTask --- .../voice-pipeline-agent/multi_task_agent.py | 104 +++++++++--------- .../livekit/agents/llm/function_context.py | 13 ++- .../livekit/agents/pipeline/agent_task.py | 85 +++++++++++--- .../livekit/agents/pipeline/pipeline_agent.py | 8 +- 4 files changed, 134 insertions(+), 76 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_task_agent.py b/examples/voice-pipeline-agent/multi_task_agent.py index 274c2d3a8..05387a6dc 100644 --- a/examples/voice-pipeline-agent/multi_task_agent.py +++ b/examples/voice-pipeline-agent/multi_task_agent.py @@ -1,6 +1,6 @@ import json import logging -from typing import Annotated, Self +from typing import Annotated from dotenv import load_dotenv from livekit import rtc @@ -13,7 +13,11 @@ llm, ) from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent -from livekit.agents.pipeline.agent_task import AgentTask +from livekit.agents.pipeline.agent_task import ( + AgentTask, + AgentTaskOptions, + _default_before_enter_cb, +) from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType from livekit.plugins import deepgram, openai, silero @@ -23,15 +27,6 @@ logger.setLevel(logging.INFO) -def get_last_n_messages( - messages: list[llm.ChatMessage], n: int -) -> list[llm.ChatMessage]: - collected_messages = messages.copy()[-n:] - while collected_messages and collected_messages[0].role in ["system", "tool"]: - collected_messages.pop(0) - return collected_messages - - user_data_template = { "order": [], "customer_name": None, @@ -41,24 +36,15 @@ def get_last_n_messages( } -def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask, str]: - agent = AgentCallContext.get_current().agent - - # keep the last n messages for the next stage - keep_last_n = 6 - task.chat_ctx.messages.extend( - get_last_n_messages(agent.chat_ctx.messages, keep_last_n) - ) - - if not message: - # add the current user data to the message - user_data = user_data_template.copy() - user_data.update(agent.user_data) - message = ( - f"Transferred from {agent.current_agent_task.name} to {task.name}. " - f"The current user data is {json.dumps(user_data)}" - ) +async def before_enter_cb( + agent: VoicePipelineAgent, task: AgentTask +) -> tuple[AgentTask, str]: + task, message = await _default_before_enter_cb(agent, task) + # additionally add the current user data to the message + user_data = user_data_template.copy() + user_data.update(agent.user_data) + message += f" The current user data is {json.dumps(user_data)}" logger.info(message) return task, message @@ -66,6 +52,7 @@ def _transfer_to(task: AgentTask, message: str | None = None) -> tuple[AgentTask class Greeter(AgentTask): def __init__(self, menu: str): super().__init__( + name="greeter", instructions=( "You are a friendly restaurant receptionist. Your tasks:\n" "1. 
Warmly greet the caller\n" @@ -81,16 +68,9 @@ def __init__(self, menu: str): "For non-order inquiries, assist directly while maintaining a professional tone." ), functions=[self.start_new_order], + options=AgentTaskOptions(before_enter_cb=before_enter_cb), ) - def can_enter(self, agent: "VoicePipelineAgent") -> bool: - return True - - @llm.ai_callable(name="transfer_to_greeter") - async def enter(self) -> tuple[Self, str]: - """Called to transfer to the greeter.""" - return _transfer_to(self) - @llm.ai_callable() async def start_new_order(self) -> str: """Called to start a new order.""" @@ -100,9 +80,32 @@ async def start_new_order(self) -> str: return "Started a new order" +""" +Another way to create a task + +@llm.ai_callable() +async def start_new_order() -> str: + ... + +def can_enter_greeter(agent: VoicePipelineAgent) -> bool: + return True + +greeter = AgentTask( + name="greeter", + instructions="...", + functions=[start_new_order], + options=AgentTaskOptions( + can_enter_cb=can_enter_greeter, + before_enter_cb=before_enter_cb, + ), +) +""" + + class OrderTaking(AgentTask): def __init__(self, menu: str): super().__init__( + name="order_taking", instructions=( "You are a professional order taker at a restaurant. Your tasks:\n" f"1. Take orders from our menu: {menu}\n" @@ -117,17 +120,16 @@ def __init__(self, menu: str): "- For non-order questions, transfer to greeter" ), functions=[self.update_order], + options=AgentTaskOptions( + can_enter_cb=self.can_enter, + before_enter_cb=before_enter_cb, + ), ) def can_enter(self, agent: "VoicePipelineAgent") -> bool: checked_out = agent.user_data.get("checked_out", False) return not checked_out - @llm.ai_callable(name="transfer_to_order_taking") - async def enter(self) -> tuple[Self, str]: - """Called to transfer to the order taking.""" - return _transfer_to(self) - @llm.ai_callable() async def update_order( self, @@ -148,6 +150,7 @@ async def update_order( class CustomerRegistration(AgentTask): def __init__(self): super().__init__( + name="customer_registration", instructions=( "You are collecting customer information for their order. Your tasks:\n" "1. Get and confirm customer's name and comfirm the spelling\n" @@ -163,6 +166,10 @@ def __init__(self): "- For non-detail questions, transfer to greeter" ), functions=[self.collect_name, self.collect_phone], + options=AgentTaskOptions( + can_enter_cb=self.can_enter, + before_enter_cb=before_enter_cb, + ), ) def can_enter(self, agent: "VoicePipelineAgent") -> bool: @@ -170,11 +177,6 @@ def can_enter(self, agent: "VoicePipelineAgent") -> bool: order = agent.user_data.get("order", None) return order and not checked_out - @llm.ai_callable(name="transfer_to_customer_registration") - async def enter(self) -> tuple[Self, str]: - """Called to transfer to the customer registration.""" - return _transfer_to(self) - @llm.ai_callable() async def collect_name( self, name: Annotated[str, llm.TypeInfo(description="The customer's name")] @@ -202,6 +204,7 @@ async def collect_phone( class Checkout(AgentTask): def __init__(self, menu: str): super().__init__( + name="checkout", instructions=( "You are a checkout agent at a restaurant. Your tasks:\n" f"1. 
Review order and prices ({menu})\n" @@ -217,6 +220,10 @@ def __init__(self, menu: str): "- For non-checkout questions, transfer to greeter" ), functions=[self.checkout], + options=AgentTaskOptions( + can_enter_cb=self.can_enter, + before_enter_cb=before_enter_cb, + ), ) def can_enter(self, agent: "VoicePipelineAgent") -> bool: @@ -227,11 +234,6 @@ def can_enter(self, agent: "VoicePipelineAgent") -> bool: return order and customer_name and customer_phone and not checked_out - @llm.ai_callable(name="transfer_to_checkout") - async def enter(self) -> tuple[Self, str]: - """Called to transfer to the checkout.""" - return _transfer_to(self) - @llm.ai_callable() async def checkout( self, diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index d523b7900..8d0e1d41a 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -34,6 +34,10 @@ class _UseDocMarker: pass +class _NoMetadataError(Exception): + pass + + METADATA_ATTR = "__livekit_ai_metadata__" USE_DOCSTRING = _UseDocMarker() @@ -166,9 +170,9 @@ def deco(f): return deco @staticmethod - def _callable_to_fnc_info(fnc: Callable) -> FunctionInfo | None: + def _callable_to_fnc_info(fnc: Callable) -> FunctionInfo: if not hasattr(fnc, METADATA_ATTR): - return None + raise _NoMetadataError("function must be decorated with ai_callable") metadata: _AIFncMetadata = getattr(fnc, METADATA_ATTR) fnc_name = metadata.name @@ -225,8 +229,9 @@ def _callable_to_fnc_info(fnc: Callable) -> FunctionInfo | None: ) def _register_ai_function(self, fnc: Callable) -> None: - fnc_info = self._callable_to_fnc_info(fnc) - if not fnc_info: + try: + fnc_info = self._callable_to_fnc_info(fnc) + except _NoMetadataError: logger.warning(f"function {fnc.__name__} does not have ai metadata") return diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index c851c8e30..3cd4be332 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -1,20 +1,61 @@ -from typing import TYPE_CHECKING, Callable, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union -from ..llm import LLM, ChatContext, FunctionContext -from ..llm.function_context import USE_DOCSTRING, FunctionInfo, ai_callable +from ..llm import LLM, ChatContext, ChatMessage, FunctionContext +from ..llm.function_context import FunctionInfo, _UseDocMarker, ai_callable from ..stt import STT if TYPE_CHECKING: from ..pipeline import VoicePipelineAgent +BeforeEnterCallback = Callable[ + ["VoicePipelineAgent", "AgentTask"], + Awaitable[Union["AgentTask", tuple["AgentTask", str]]], +] + + +def _get_last_n_messages(messages: list[ChatMessage], n: int) -> list[ChatMessage]: + collected_messages = messages.copy()[-n:] + while collected_messages and collected_messages[0].role in ["system", "tool"]: + collected_messages.pop(0) + return collected_messages + + +async def _default_before_enter_cb( + agent: "VoicePipelineAgent", task: "AgentTask" +) -> tuple["AgentTask", str]: + # keep the last n messages for the next stage + keep_last_n = 6 + previous_messages = _get_last_n_messages(agent.chat_ctx.messages, keep_last_n) + task.chat_ctx.messages.extend(previous_messages) + + message = f"Transferred from {agent.current_agent_task.name} to {task.name}." 
+ return task, message + + +def _default_can_enter_cb(agent: "VoicePipelineAgent") -> bool: + return True + + +@dataclass(frozen=True) +class AgentTaskOptions: + can_enter_cb: Callable[["VoicePipelineAgent"], bool] = _default_can_enter_cb + """callback to check if the task can be entered""" + before_enter_cb_description: Optional[Union[str, _UseDocMarker]] = None + """description of the before_enter callback, use `Called to transfer to {task_name}` if not provided""" + before_enter_cb: BeforeEnterCallback = _default_before_enter_cb + """callback to call before entering the task""" + + class AgentTask: def __init__( self, + name: Optional[str] = None, instructions: Optional[str] = None, functions: Optional[list[Callable]] = None, llm: Optional[LLM] = None, - name: Optional[str] = None, + options: AgentTaskOptions = AgentTaskOptions(), ) -> None: self._chat_ctx = ChatContext() if instructions: @@ -27,26 +68,36 @@ def __init__( self._llm = llm self._stt = None - self._name = name or self.__class__.__name__ - self._enter_fnc_info = FunctionContext._callable_to_fnc_info(self.enter) - if not self._enter_fnc_info: - raise ValueError("enter function must be decorated with ai_callable") - def can_enter(self, agent: "VoicePipelineAgent") -> bool: - return True + self._task_name = name or self.__class__.__name__ + self._opts = options + + # transfer function + from ..pipeline import AgentCallContext + + fnc_desc = ( + options.before_enter_cb_description + if options.before_enter_cb_description is not None + else f"Called to transfer to {self._task_name}" + ) + + @ai_callable(name=f"transfer_to_{self._task_name}", description=fnc_desc) + async def transfer_fnc() -> Union["AgentTask", tuple["AgentTask", str]]: + agent = AgentCallContext.get_current().agent + return await self._opts.before_enter_cb(agent, self) + + self._transfer_fnc_info = FunctionContext._callable_to_fnc_info(transfer_fnc) - @ai_callable(name="transfer_to_task", description=USE_DOCSTRING) - async def enter(self) -> Union["AgentTask", tuple["AgentTask", str]]: - """Called to enter the task.""" - return self + def _can_enter(self, agent: "VoicePipelineAgent") -> bool: + return self._opts.can_enter_cb(agent) @property def name(self) -> str: - return self._name + return self._task_name @property - def enter_fnc_info(self) -> FunctionInfo: - return self._enter_fnc_info # type: ignore + def transfer_fnc_info(self) -> FunctionInfo: + return self._transfer_fnc_info # type: ignore @property def chat_ctx(self) -> ChatContext: diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index d87adf4b3..345964451 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -339,7 +339,7 @@ def current_agent_task(self, task: AgentTask) -> None: @property def fnc_ctx(self) -> FunctionContext | None: - available_tasks = [task for task in self._agent_tasks if task.can_enter(self)] + available_tasks = [task for task in self._agent_tasks if task._can_enter(self)] if not available_tasks: # no transition available, return the current function context return self._current_agent_task.fnc_ctx @@ -350,11 +350,11 @@ def fnc_ctx(self) -> FunctionContext | None: else FunctionContext() ) for task in available_tasks: - if task.enter_fnc_info.name in new_fnc_ctx._fncs: + if task.transfer_fnc_info.name in new_fnc_ctx._fncs: raise ValueError( - f"duplicate ai_callable name: {task.enter_fnc_info.name}" + f"duplicate 
ai_callable name: {task.transfer_fnc_info.name}" ) - new_fnc_ctx._fncs[task.enter_fnc_info.name] = task.enter_fnc_info + new_fnc_ctx._fncs[task.transfer_fnc_info.name] = task.transfer_fnc_info return new_fnc_ctx From d16c8473d8f1db2d7fcd95bb63ce1422fd56a3b7 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Thu, 26 Dec 2024 12:26:54 +0800 Subject: [PATCH 11/13] add news mailer example --- .../multi_task/news_mailer.py | 204 ++++++++++++++++++ .../restaurant_agent.py} | 0 2 files changed, 204 insertions(+) create mode 100644 examples/voice-pipeline-agent/multi_task/news_mailer.py rename examples/voice-pipeline-agent/{multi_task_agent.py => multi_task/restaurant_agent.py} (100%) diff --git a/examples/voice-pipeline-agent/multi_task/news_mailer.py b/examples/voice-pipeline-agent/multi_task/news_mailer.py new file mode 100644 index 000000000..6ea2fc8dd --- /dev/null +++ b/examples/voice-pipeline-agent/multi_task/news_mailer.py @@ -0,0 +1,204 @@ +import asyncio +import json +import logging +from typing import Annotated, TypedDict + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.pipeline.agent_task import AgentTask +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("news-mailer") +logger.setLevel(logging.INFO) + + +class UserData(TypedDict): + query: str | None + news: str | None + email: str | None + + +@llm.ai_callable() +async def query_news( + query: Annotated[str, llm.TypeInfo(description="The query user asked for")], +) -> str: + """Called to query news from the internet. + Tell the user you are checking the news when calling this function.""" + logger.info(f"Querying news for {query}") + perplexity_llm = openai.LLM.with_perplexity( + model="llama-3.1-sonar-small-128k-online" + ) + chat_ctx = llm.ChatContext().append( + role="system", + text="Search the recent news articles about the query.", + ) + chat_ctx.append(role="user", text=query) + llm_stream = perplexity_llm.chat(chat_ctx=chat_ctx) + news = "" + async for chunk in llm_stream: + if not chunk or not chunk.choices or not chunk.choices[0].delta.content: + continue + news += chunk.choices[0].delta.content + + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["query"] = query + user_data["news"] = news + logger.info(f"The news about {query} collected") + return news + + +@llm.ai_callable() +async def send_news_email() -> str: + """Called to send the news to the user's email address.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + email = user_data.get("email") + news = user_data.get("news") + + if not email: + return "email is not collected" + + if not news: + return "news is not collected" + + # mock sending email + query = user_data.get("query") + logger.info(f"Sending news about {query} to {email}") + await asyncio.sleep(2) + return f"The news about {query} is sent to {email}" + + +@llm.ai_callable() +async def verify_email( + email: Annotated[str, llm.TypeInfo(description="The collected email address")], +) -> str: + """Called to verify the user's email address.""" + if "@" not in email: + return "The email address is not valid, please confirm with the user." 
+ + # Potentially show the email on the screen for the user to confirm + return "The email address is valid. Please confirm with the user for the spelling." + + +@llm.ai_callable() +async def update_email( + email: Annotated[str, llm.TypeInfo(description="The collected email address")], +) -> str: + """Called to update the user's email address.""" + + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["email"] = email + logger.info(f"The email is updated to {email}") + return f"The email is updated to {email}." + + +news_mailer = AgentTask( + name="news_mailer", + instructions=( + "You are a friendly assistant that can query news from the internet." + "Summarize the news in 50 words or less and ask the user if they want to receive the news by email." + "Use email_collector to collect the user's email address." + ), + functions=[query_news, send_news_email], +) + +email_collector = AgentTask( + name="email_collector", + instructions=( + "You are a friendly assistant that can collect the user's email address. Your tasks:\n" + "1. Collect the user's email address, help to complete the @ and domain part if possible.\n" + "2. Verify the address with `verify_email` function until the user confirms.\n" + "3. Update the email address after the user confirms.\n" + "Transfer back to news_mailer after the email is updated." + ), + functions=[update_email, verify_email], +) + + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + chat_log_file = "news_mailer.log" + + # Set up chat logger + chat_logger = logging.getLogger("chat_logger") + chat_logger.setLevel(logging.INFO) + handler = logging.FileHandler(chat_log_file) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + chat_logger.addHandler(handler) + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=openai.TTS(), + initial_task=news_mailer, + available_tasks=[news_mailer, email_collector], + max_nested_fnc_calls=3, # may call functions in the transition function + ) + + # read text input from the room for easy testing + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.debug("Text input received", extra={"message": data["message"]}) + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + # write the chat log to a file + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + chat_logger.info(f"{message.role}: {message.content}") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + chat_logger.info(f"fnc_calls_collected: {fnc_infos}") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + chat_logger.info(f"fnc_calls_finished: {called_infos}") + + # Start the assistant. This will automatically publish a microphone track and listen to the participant. 
+ agent.start(ctx.room, participant) + await agent.say("Welcome to news mailer! How may I assist you today?") + + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/examples/voice-pipeline-agent/multi_task_agent.py b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py similarity index 100% rename from examples/voice-pipeline-agent/multi_task_agent.py rename to examples/voice-pipeline-agent/multi_task/restaurant_agent.py From 92ab58dbfc5c0fbb7c4d2e9cfe13c2b16bd5a104 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Thu, 26 Dec 2024 23:46:49 +0800 Subject: [PATCH 12/13] update restaurant instrunctions --- .../multi_task/restaurant_agent.py | 66 +++++++++---------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_task/restaurant_agent.py b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py index 05387a6dc..501ee5e82 100644 --- a/examples/voice-pipeline-agent/multi_task/restaurant_agent.py +++ b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py @@ -54,13 +54,10 @@ def __init__(self, menu: str): super().__init__( name="greeter", instructions=( - "You are a friendly restaurant receptionist. Your tasks:\n" + "You are a friendly restaurant receptionist. Your jobs are:\n" "1. Warmly greet the caller\n" f"2. Ask if they'd like to place an order. (menu: {menu})\n" - "Transfer to:\n" - "- order_taking: when ready to place order\n" - "- customer_registration: only after order is complete\n" - "- checkout: only after customer details are collected\n\n" + "3. Transfer to the corresponding task using functions based on the user's response.\n" "Important:\n" "- If a transfer function is unavailable, it means prerequisites aren't met\n" "- Guide the customer to complete previous steps first\n" @@ -68,7 +65,13 @@ def __init__(self, menu: str): "For non-order inquiries, assist directly while maintaining a professional tone." ), functions=[self.start_new_order], - options=AgentTaskOptions(before_enter_cb=before_enter_cb), + options=AgentTaskOptions( + before_enter_cb=before_enter_cb, + before_enter_cb_description=( + "Called to transfer to the greeter when the user asks for general questions " + "or starting over after checking out." + ), + ), ) @llm.ai_callable() @@ -76,6 +79,7 @@ async def start_new_order(self) -> str: """Called to start a new order.""" agent = AgentCallContext.get_current().agent agent.user_data.clear() + # probably also clear the chat ctx of tasks logger.info("Started a new order") return "Started a new order" @@ -107,22 +111,20 @@ def __init__(self, menu: str): super().__init__( name="order_taking", instructions=( - "You are a professional order taker at a restaurant. Your tasks:\n" + "You are a professional order taker at a restaurant. Your jobs are:\n" f"1. Take orders from our menu: {menu}\n" "2. Clarify special requests\n" "3. Confirm order accuracy\n\n" - "Transfer to:\n" - "- customer_registration: when order is confirmed\n" - "- greeter: for general questions or starting over\n\n" - "Important:\n" - "- Use update_order function to save the order\n" - "- Ensure order is complete before transferring to customer details\n" - "- For non-order questions, transfer to greeter" + "Transfer to the next step using functions after the order is confirmed." 
), functions=[self.update_order], options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, + before_enter_cb_description=( + "Called to transfer to the order taking " + "when the user wants to take an order or modify their order." + ), ), ) @@ -152,23 +154,20 @@ def __init__(self): super().__init__( name="customer_registration", instructions=( - "You are collecting customer information for their order. Your tasks:\n" + "You are collecting customer information for their order. Your jobs are:\n" "1. Get and confirm customer's name and comfirm the spelling\n" "2. Get phone number and verify it's correct\n" - "3. Repeat both pieces of information back to ensure accuracy\n" - "Transfer to:\n" - "- checkout: when all details are confirmed\n" - "- order_taking: to modify the order\n" - "- greeter: for general questions\n\n" - "Important:\n" - "- Use collect_name and collect_phone functions to save details\n" - "- Verify all information before proceeding to checkout\n" - "- For non-detail questions, transfer to greeter" + "3. Repeat both pieces of information back to ensure accuracy\n\n" + "Transfer to the next step using functions after the information is confirmed." ), functions=[self.collect_name, self.collect_phone], options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, + before_enter_cb_description=( + "Called to transfer to the customer registration " + "after the order is confirmed or the user wants to update their information." + ), ), ) @@ -206,23 +205,20 @@ def __init__(self, menu: str): super().__init__( name="checkout", instructions=( - "You are a checkout agent at a restaurant. Your tasks:\n" + "You are a checkout agent at a restaurant. Your jobs are:\n" f"1. Review order and prices ({menu})\n" "2. Calculate and confirm total\n" - "3. Process checkout\n\n" - "Transfer to:\n" - "- order_taking: to modify order\n" - "- customer_registration: to update information\n" - "- greeter: after checkout or for general questions\n\n" - "Important:\n" - "- Use checkout function with final expense\n" - "- After successful checkout, transfer to greeter\n" - "- For non-checkout questions, transfer to greeter" + "3. Process checkout and confirm the total\n\n" + "Transfer back to the greeter using functions after checkout." ), functions=[self.checkout], options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, + before_enter_cb_description=( + "Called to transfer to the checkout " + "after the user confirms the order and registration." 
+ ), ), ) @@ -250,7 +246,7 @@ async def checkout( async def entrypoint(ctx: JobContext): await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) - chat_log_file = "multi_task_chat_log.txt" + chat_log_file = "restaurant_agent.log" menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" # Set up chat logger From a94f06dfdffb87d49168c43f749f822e92f4e467 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Thu, 26 Dec 2024 23:54:42 +0800 Subject: [PATCH 13/13] rename to transfer_function_description --- .../multi_task/restaurant_agent.py | 8 ++++---- .../livekit/agents/pipeline/agent_task.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/voice-pipeline-agent/multi_task/restaurant_agent.py b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py index 501ee5e82..9ab5649bc 100644 --- a/examples/voice-pipeline-agent/multi_task/restaurant_agent.py +++ b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py @@ -67,7 +67,7 @@ def __init__(self, menu: str): functions=[self.start_new_order], options=AgentTaskOptions( before_enter_cb=before_enter_cb, - before_enter_cb_description=( + transfer_function_description=( "Called to transfer to the greeter when the user asks for general questions " "or starting over after checking out." ), @@ -121,7 +121,7 @@ def __init__(self, menu: str): options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, - before_enter_cb_description=( + transfer_function_description=( "Called to transfer to the order taking " "when the user wants to take an order or modify their order." ), @@ -164,7 +164,7 @@ def __init__(self): options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, - before_enter_cb_description=( + transfer_function_description=( "Called to transfer to the customer registration " "after the order is confirmed or the user wants to update their information." ), @@ -215,7 +215,7 @@ def __init__(self, menu: str): options=AgentTaskOptions( can_enter_cb=self.can_enter, before_enter_cb=before_enter_cb, - before_enter_cb_description=( + transfer_function_description=( "Called to transfer to the checkout " "after the user confirms the order and registration." 
), diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py index 3cd4be332..3747b2bbf 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_task.py +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -42,8 +42,8 @@ def _default_can_enter_cb(agent: "VoicePipelineAgent") -> bool: class AgentTaskOptions: can_enter_cb: Callable[["VoicePipelineAgent"], bool] = _default_can_enter_cb """callback to check if the task can be entered""" - before_enter_cb_description: Optional[Union[str, _UseDocMarker]] = None - """description of the before_enter callback, use `Called to transfer to {task_name}` if not provided""" + transfer_function_description: Optional[Union[str, _UseDocMarker]] = None + """description of the transfer function, use `Called to transfer to {task_name}` if not provided""" before_enter_cb: BeforeEnterCallback = _default_before_enter_cb """callback to call before entering the task""" @@ -75,13 +75,15 @@ def __init__( # transfer function from ..pipeline import AgentCallContext - fnc_desc = ( - options.before_enter_cb_description - if options.before_enter_cb_description is not None + transfer_fnc_desc = ( + options.transfer_function_description + if options.transfer_function_description is not None else f"Called to transfer to {self._task_name}" ) - @ai_callable(name=f"transfer_to_{self._task_name}", description=fnc_desc) + @ai_callable( + name=f"transfer_to_{self._task_name}", description=transfer_fnc_desc + ) async def transfer_fnc() -> Union["AgentTask", tuple["AgentTask", str]]: agent = AgentCallContext.get_current().agent return await self._opts.before_enter_cb(agent, self)
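
A minimal usage sketch of the API as it stands after this last patch, assuming the `AgentTask` / `AgentTaskOptions` interface exactly as introduced in the series; the task name, instructions, and `lookup_reservation` function below are illustrative only and do not appear in the patches.

```python
# Hypothetical sketch (not part of the patch series): declaring a task without
# subclassing and customizing the auto-generated transfer function's description
# via the renamed `transfer_function_description` option.
from livekit.agents import llm
from livekit.agents.pipeline.agent_task import AgentTask, AgentTaskOptions


@llm.ai_callable()
async def lookup_reservation() -> str:
    """Called to look up an existing reservation."""
    # illustrative stub; a real implementation would query a booking system
    return "No reservation found"


reservations = AgentTask(
    name="reservations",
    instructions="You help callers check the status of an existing reservation.",
    functions=[lookup_reservation],
    options=AgentTaskOptions(
        # used as the description of the generated `transfer_to_reservations` function;
        # if omitted, the default "Called to transfer to reservations" is used
        transfer_function_description=(
            "Called to transfer to reservations when the caller asks about a booking."
        ),
    ),
)
```

Declaring tasks this way keeps the transfer-function wording next to the task definition instead of relying on the library's default description, mirroring how the restaurant example passes `transfer_function_description` per task.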