diff --git a/examples/multimodal-agent/openai_multitask.py b/examples/multimodal-agent/openai_multitask.py new file mode 100644 index 000000000..ed6ae70cc --- /dev/null +++ b/examples/multimodal-agent/openai_multitask.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Annotated, Optional, TypedDict + +from dotenv import load_dotenv +from livekit.agents import ( + AutoSubscribe, + JobContext, + WorkerOptions, + WorkerType, + cli, + llm, + multimodal, +) +from livekit.agents.pipeline import AgentTask +from livekit.plugins import openai +from livekit.plugins.openai.realtime import RealtimeCallContext + +load_dotenv() + +logger = logging.getLogger("my-worker") +logger.setLevel(logging.INFO) + + +def update_context(task: AgentTask, chat_ctx: llm.ChatContext) -> None: + # last_chat_ctx = chat_ctx.truncate(keep_last_n=4, keep_tool_calls=False) + # task.inject_chat_ctx(last_chat_ctx) + pass + + +class UserData(TypedDict): + customer_name: Optional[str] + customer_phone: Optional[str] + + reservation_time: Optional[str] + + order: Optional[list[str]] + customer_credit_card: Optional[str] + customer_credit_card_expiry: Optional[str] + customer_credit_card_cvv: Optional[str] + expense: Optional[float] + checked_out: Optional[bool] + + +# some common functions +@llm.ai_callable() +async def update_name( + name: Annotated[str, llm.TypeInfo(description="The customer's name")], +) -> str: + """Called when the user provides their name. + Confirm the spelling with the user before calling the function.""" + session = RealtimeCallContext.get_current().session + user_data: UserData = session.user_data + user_data["customer_name"] = name + return f"The name is updated to {name}" + + +@llm.ai_callable() +async def update_phone( + phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], +) -> str: + """Called when the user provides their phone number. + Confirm the spelling with the user before calling the function.""" + + session = RealtimeCallContext.get_current().session + user_data: UserData = session.user_data + user_data["customer_phone"] = phone + return f"The phone number is updated to {phone}" + + +@llm.ai_callable() +async def to_greeter() -> tuple[AgentTask, str]: + """Called when user asks any unrelated questions or requests any other services not in your job description.""" + session = RealtimeCallContext.get_current().session + next_task = AgentTask.get_task("greeter") + update_context(next_task, session.chat_ctx_copy()) + return next_task, f"User data: {session.user_data}" + + +class Greeter(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + f"You are a friendly restaurant receptionist. The menu is: {menu}\n" + "Your jobs are to greet the caller and understand if they want to " + "make a reservation or order takeaway. Guide them to the right agent." + ) + ) + self.menu = menu + + @llm.ai_callable() + async def to_reservation(self) -> tuple[AgentTask, str]: + """Called when user wants to make a reservation. 
This function handles transitioning to the reservation agent
+        who will collect the necessary details like reservation time, customer name and phone number."""
+        session = RealtimeCallContext.get_current().session
+        next_task = AgentTask.get_task("reservation")
+        update_context(next_task, session.chat_ctx_copy())
+        return next_task, f"User info: {session.user_data}"
+
+    @llm.ai_callable()
+    async def to_takeaway(self) -> tuple[AgentTask, str]:
+        """Called when the user wants to place a takeaway order. This includes handling orders for pickup,
+        delivery, or when the user wants to proceed to checkout with their existing order."""
+        session = RealtimeCallContext.get_current().session
+        next_task = AgentTask.get_task("takeaway")
+        update_context(next_task, session.chat_ctx_copy())
+        return next_task, f"User info: {session.user_data}"
+
+
+class Reservation(AgentTask):
+    def __init__(self):
+        super().__init__(
+            instructions=(
+                "You are a reservation agent at a restaurant. Your jobs are to ask for "
+                "the reservation time, then customer's name, and phone number. Then "
+                "confirm the reservation details with the customer."
+            ),
+            functions=[update_name, update_phone, to_greeter],
+        )
+
+    @llm.ai_callable()
+    async def update_reservation_time(
+        self,
+        time: Annotated[str, llm.TypeInfo(description="The reservation time")],
+    ) -> str:
+        """Called when the user provides their reservation time.
+        Confirm the time with the user before calling the function."""
+        session = RealtimeCallContext.get_current().session
+        user_data: UserData = session.user_data
+        user_data["reservation_time"] = time
+        return f"The reservation time is updated to {time}"
+
+    @llm.ai_callable()
+    async def confirm_reservation(self) -> tuple[AgentTask, str] | str:
+        """Called when the user confirms the reservation.
+        Call this function to transfer to the next step."""
+        session = RealtimeCallContext.get_current().session
+        user_data: UserData = session.user_data
+        if not user_data.get("customer_name") or not user_data.get("customer_phone"):
+            return "Please provide your name and phone number first."
+
+        if not user_data.get("reservation_time"):
+            return "Please provide the reservation time first."
+
+        next_task = AgentTask.get_task("greeter")
+        update_context(next_task, session.chat_ctx_copy())
+        return next_task, f"Reservation confirmed. User data: {user_data}"
+
+
+class Takeaway(AgentTask):
+    def __init__(self, menu: str):
+        super().__init__(
+            instructions=(
+                f"Our menu is: {menu}. Your jobs are to record the order from the "
+                "customer. Clarify special requests and confirm the order with the "
+                "customer."
+            ),
+            functions=[to_greeter],
+        )
+
+    @llm.ai_callable()
+    async def update_order(
+        self,
+        items: Annotated[
+            list[str], llm.TypeInfo(description="The items of the full order")
+        ],
+    ) -> str:
+        """Called when the user creates or updates their order."""
+        session = RealtimeCallContext.get_current().session
+        user_data: UserData = session.user_data
+        user_data["order"] = items
+        return f"Updated order to {items}"
+
+    @llm.ai_callable()
+    async def to_checkout(self) -> tuple[AgentTask, str] | str:
+        """Called when the user confirms the order. Call this function to transfer to the checkout step.
+        Double check the order with the user before calling the function."""
+        session = RealtimeCallContext.get_current().session
+        user_data: UserData = session.user_data
+        if not user_data.get("order"):
+            return "No takeaway order found. Please make an order first."
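+        # Returning an (AgentTask, str) tuple below is the hand-off convention
+        # used throughout this example: the first element becomes the active
+        # task, the second becomes the tool result shown to the model.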
+ + next_task = AgentTask.get_task("checkout") + update_context(next_task, session.chat_ctx_copy()) + return next_task, f"User info: {user_data}" + + +class Checkout(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + "You are a professional checkout agent at a restaurant. The menu is: " + f"{menu}. Your are responsible for confirming the expense of the " + "order and then collecting customer's name, phone number and credit card " + "information, including the card number, expiry date, and CVV step by step." + ), + functions=[update_name, update_phone, to_greeter], + ) + + @llm.ai_callable() + async def confirm_expense( + self, + expense: Annotated[float, llm.TypeInfo(description="The expense of the order")], + ) -> str: + """Called when the user confirms the expense.""" + session = RealtimeCallContext.get_current().session + user_data: UserData = session.user_data + user_data["expense"] = expense + return f"The expense is confirmed to be {expense}" + + @llm.ai_callable() + async def update_credit_card( + self, + number: Annotated[str, llm.TypeInfo(description="The credit card number")], + expiry: Annotated[ + str, llm.TypeInfo(description="The expiry date of the credit card") + ], + cvv: Annotated[str, llm.TypeInfo(description="The CVV of the credit card")], + ) -> str: + """Called when the user provides their credit card number, expiry date, and CVV. + Confirm the spelling with the user before calling the function.""" + session = RealtimeCallContext.get_current().session + user_data: UserData = session.user_data + user_data["customer_credit_card"] = number + user_data["customer_credit_card_expiry"] = expiry + user_data["customer_credit_card_cvv"] = cvv + return f"The credit card number is updated to {number}" + + @llm.ai_callable() + async def confirm_checkout(self) -> str: + """Called when the user confirms the checkout. + Double check the information with the user before calling the function.""" + session = RealtimeCallContext.get_current().session + user_data: UserData = session.user_data + if not user_data.get("expense"): + return "Please confirm the expense first." + + if ( + not user_data.get("customer_credit_card") + or not user_data.get("customer_credit_card_expiry") + or not user_data.get("customer_credit_card_cvv") + ): + return "Please provide the credit card information first." + + user_data["checked_out"] = True + next_task = AgentTask.get_task("greeter") + update_context(next_task, session.chat_ctx_copy()) + return next_task, f"User checked out. 
User info: {user_data}" + + @llm.ai_callable() + async def to_takeaway(self) -> tuple[AgentTask, str]: + """Called when the user wants to update their order.""" + session = RealtimeCallContext.get_current().session + next_task = AgentTask.get_task("takeaway") + update_context(next_task, session.chat_ctx_copy()) + return next_task, f"User info: {session.user_data}" + + +async def entrypoint(ctx: JobContext): + logger.info("starting entrypoint") + + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + participant = await ctx.wait_for_participant() + + # create tasks + menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" + greeter = AgentTask.register_task(Greeter(menu), "greeter") + AgentTask.register_task(Reservation(), "reservation") + AgentTask.register_task(Takeaway(menu), "takeaway") + AgentTask.register_task(Checkout(menu), "checkout") + + agent = multimodal.MultimodalAgent( + model=openai.realtime.RealtimeModel( + voice="alloy", + temperature=0.8, + instructions=greeter.instructions, + turn_detection=openai.realtime.ServerVadOptions( + threshold=0.6, prefix_padding_ms=200, silence_duration_ms=500 + ), + ), + initial_task=greeter, + ) + agent.start(ctx.room, participant) + + await asyncio.sleep(1) + session: openai.realtime.RealtimeSession = agent._session + session.response.create() + + +if __name__ == "__main__": + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM)) diff --git a/examples/voice-pipeline-agent/multi_task/news_mailer.py b/examples/voice-pipeline-agent/multi_task/news_mailer.py new file mode 100644 index 000000000..6ea2fc8dd --- /dev/null +++ b/examples/voice-pipeline-agent/multi_task/news_mailer.py @@ -0,0 +1,204 @@ +import asyncio +import json +import logging +from typing import Annotated, TypedDict + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.pipeline.agent_task import AgentTask +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("news-mailer") +logger.setLevel(logging.INFO) + + +class UserData(TypedDict): + query: str | None + news: str | None + email: str | None + + +@llm.ai_callable() +async def query_news( + query: Annotated[str, llm.TypeInfo(description="The query user asked for")], +) -> str: + """Called to query news from the internet. 
+    Tell the user you are checking the news when calling this function."""
+    logger.info(f"Querying news for {query}")
+    perplexity_llm = openai.LLM.with_perplexity(
+        model="llama-3.1-sonar-small-128k-online"
+    )
+    chat_ctx = llm.ChatContext().append(
+        role="system",
+        text="Search the recent news articles about the query.",
+    )
+    chat_ctx.append(role="user", text=query)
+    llm_stream = perplexity_llm.chat(chat_ctx=chat_ctx)
+    news = ""
+    async for chunk in llm_stream:
+        if not chunk or not chunk.choices or not chunk.choices[0].delta.content:
+            continue
+        news += chunk.choices[0].delta.content
+
+    agent = AgentCallContext.get_current().agent
+    user_data: UserData = agent.user_data
+    user_data["query"] = query
+    user_data["news"] = news
+    logger.info(f"Collected the news about {query}")
+    return news
+
+
+@llm.ai_callable()
+async def send_news_email() -> str:
+    """Called to send the news to the user's email address."""
+    agent = AgentCallContext.get_current().agent
+    user_data: UserData = agent.user_data
+    email = user_data.get("email")
+    news = user_data.get("news")
+
+    if not email:
+        return "email is not collected"
+
+    if not news:
+        return "news is not collected"
+
+    # mock sending email
+    query = user_data.get("query")
+    logger.info(f"Sending news about {query} to {email}")
+    await asyncio.sleep(2)
+    return f"The news about {query} is sent to {email}"
+
+
+@llm.ai_callable()
+async def verify_email(
+    email: Annotated[str, llm.TypeInfo(description="The collected email address")],
+) -> str:
+    """Called to verify the user's email address."""
+    if "@" not in email:
+        return "The email address is not valid. Please confirm with the user."
+
+    # Potentially show the email on the screen for the user to confirm
+    return "The email address is valid. Please confirm the spelling with the user."
+
+
+@llm.ai_callable()
+async def update_email(
+    email: Annotated[str, llm.TypeInfo(description="The collected email address")],
+) -> str:
+    """Called to update the user's email address."""
+
+    agent = AgentCallContext.get_current().agent
+    user_data: UserData = agent.user_data
+    user_data["email"] = email
+    logger.info(f"The email is updated to {email}")
+    return f"The email is updated to {email}."
+
+
+news_mailer = AgentTask(
+    name="news_mailer",
+    instructions=(
+        "You are a friendly assistant that can query news from the internet. "
+        "Summarize the news in 50 words or less and ask the user if they want to receive the news by email. "
+        "Use email_collector to collect the user's email address."
+    ),
+    functions=[query_news, send_news_email],
+)
+
+email_collector = AgentTask(
+    name="email_collector",
+    instructions=(
+        "You are a friendly assistant that can collect the user's email address. Your tasks:\n"
+        "1. Collect the user's email address, helping to complete the @ and domain parts if possible.\n"
+        "2. Verify the address with the `verify_email` function until the user confirms.\n"
+        "3. Update the email address after the user confirms.\n"
+        "Transfer back to news_mailer after the email is updated."
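+        # Note: no explicit transfer function is registered on this task; the
+        # transfer back to news_mailer presumably relies on the available_tasks
+        # wiring in the entrypoint below.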
+ ), + functions=[update_email, verify_email], +) + + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + chat_log_file = "news_mailer.log" + + # Set up chat logger + chat_logger = logging.getLogger("chat_logger") + chat_logger.setLevel(logging.INFO) + handler = logging.FileHandler(chat_log_file) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + chat_logger.addHandler(handler) + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=openai.TTS(), + initial_task=news_mailer, + available_tasks=[news_mailer, email_collector], + max_nested_fnc_calls=3, # may call functions in the transition function + ) + + # read text input from the room for easy testing + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.debug("Text input received", extra={"message": data["message"]}) + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + # write the chat log to a file + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + chat_logger.info(f"{message.role}: {message.content}") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + chat_logger.info(f"fnc_calls_collected: {fnc_infos}") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + chat_logger.info(f"fnc_calls_finished: {called_infos}") + + # Start the assistant. This will automatically publish a microphone track and listen to the participant. + agent.start(ctx.room, participant) + await agent.say("Welcome to news mailer! 
How may I assist you today?") + + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/examples/voice-pipeline-agent/multi_task/restaurant_agent.py b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py new file mode 100644 index 000000000..aa539e7d2 --- /dev/null +++ b/examples/voice-pipeline-agent/multi_task/restaurant_agent.py @@ -0,0 +1,354 @@ +import json +import logging +from typing import Annotated, AsyncIterable, Optional, TypedDict + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.pipeline.agent_task import AgentTask +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import cartesia, deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("multi-task-agent") +logger.setLevel(logging.INFO) + + +class UserData(TypedDict): + customer_name: Optional[str] + customer_phone: Optional[str] + + reservation_time: Optional[str] + + order: Optional[list[str]] + customer_credit_card: Optional[str] + customer_credit_card_expiry: Optional[str] + customer_credit_card_cvv: Optional[str] + expense: Optional[float] + checked_out: Optional[bool] + + +def update_context( + task: AgentTask, chat_ctx: llm.ChatContext, user_data: UserData +) -> None: + injected_chat_ctx = chat_ctx.truncate(keep_last_n=4, keep_tool_calls=False) + injected_chat_ctx.append( + role="system", + text=f"Currently collected user data: {user_data}", + ) + logger.info("update task context", extra={"user_data": user_data}) + task.inject_chat_ctx(injected_chat_ctx) + + +# some common functions +@llm.ai_callable() +async def update_name( + name: Annotated[str, llm.TypeInfo(description="The customer's name")], +) -> str: + """Called when the user provides their name. + Confirm the spelling with the user before calling the function.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["customer_name"] = name + return f"The name is updated to {name}" + + +@llm.ai_callable() +async def update_phone( + phone: Annotated[str, llm.TypeInfo(description="The customer's phone number")], +) -> str: + """Called when the user provides their phone number. + Confirm the spelling with the user before calling the function.""" + + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["customer_phone"] = phone + return f"The phone number is updated to {phone}" + + +@llm.ai_callable() +async def to_greeter() -> tuple[AgentTask, str]: + """Called when user asks any unrelated questions or requests any other services not in your job description.""" + agent = AgentCallContext.get_current().agent + next_task = AgentTask.get_task("greeter") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to greeter." + + +class Greeter(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + f"You are a friendly restaurant receptionist. The menu is: {menu}\n" + "Your jobs are to greet the caller and understand if they want to " + "make a reservation or order takeaway. Guide them to the right agent. 
" + ), + llm=openai.LLM(model="gpt-4o-mini", parallel_tool_calls=False), + ) + self.menu = menu + + @llm.ai_callable() + async def to_reservation(self) -> tuple[AgentTask, str]: + """Called when user wants to make a reservation. This function handles transitioning to the reservation agent + who will collect the necessary details like reservation time, customer name and phone number.""" + agent = AgentCallContext.get_current().agent + next_task = AgentTask.get_task("reservation") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to reservation." + + @llm.ai_callable() + async def to_takeaway(self) -> tuple[AgentTask, str]: + """Called when the user wants to place a takeaway order. This includes handling orders for pickup, + delivery, or when the user wants to proceed to checkout with their existing order.""" + agent = AgentCallContext.get_current().agent + next_task = AgentTask.get_task("takeaway") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to takeaway." + + +class Reservation(AgentTask): + def __init__(self): + super().__init__( + instructions=( + "You are a reservation agent at a restaurant. Your jobs are to ask for " + "the reservation time, then customer's name, and phone number. Then " + "confirm the reservation details with the customer." + ), + functions=[update_name, update_phone, to_greeter], + ) + + @llm.ai_callable() + async def update_reservation_time( + self, + time: Annotated[str, llm.TypeInfo(description="The reservation time")], + ) -> str: + """Called when the user provides their reservation time. + Confirm the time with the user before calling the function.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["reservation_time"] = time + return f"The reservation time is updated to {time}" + + @llm.ai_callable() + async def confirm_reservation(self) -> str: + """Called when the user confirms the reservation. + Call this function to transfer to the next step.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + if not user_data.get("customer_name") or not user_data.get("customer_phone"): + return "Please provide your name and phone number first." + + if not user_data.get("reservation_time"): + return "Please provide reservation time first." + + next_task = AgentTask.get_task("greeter") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to greeter." + + +class Takeaway(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + f"Our menu is: {menu}. Your jobs are to record the order from the " + "customer. Clarify special requests and confirm the order with the " + "customer." + ), + functions=[to_greeter], + ) + + @llm.ai_callable() + async def update_order( + self, + items: Annotated[ + list[str], llm.TypeInfo(description="The items of the full order") + ], + ) -> str: + """Called when the user create or update their order.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["order"] = items + return f"Updated order to {items}" + + @llm.ai_callable() + async def to_checkout(self) -> tuple[AgentTask, str]: + """Called when the user confirms the order. Call this function to transfer to the checkout step. 
+ Double check the order with the user before calling the function.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + if not user_data.get("order"): + return "No takeaway order found. Please make an order first." + + next_task = AgentTask.get_task("checkout") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to checkout." + + +class Checkout(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + "You are a professional checkout agent at a restaurant. The menu is: " + f"{menu}. Your are responsible for confirming the expense of the " + "order and then collecting customer's name, phone number and credit card " + "information, including the card number, expiry date, and CVV step by step." + ), + functions=[update_name, update_phone, to_greeter], + ) + + @llm.ai_callable() + async def confirm_expense( + self, + expense: Annotated[float, llm.TypeInfo(description="The expense of the order")], + ) -> str: + """Called when the user confirms the expense.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["expense"] = expense + return f"The expense is confirmed to be {expense}" + + @llm.ai_callable() + async def update_credit_card( + self, + number: Annotated[str, llm.TypeInfo(description="The credit card number")], + expiry: Annotated[ + str, llm.TypeInfo(description="The expiry date of the credit card") + ], + cvv: Annotated[str, llm.TypeInfo(description="The CVV of the credit card")], + ) -> str: + """Called when the user provides their credit card number, expiry date, and CVV. + Confirm the spelling with the user before calling the function.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + user_data["customer_credit_card"] = number + user_data["customer_credit_card_expiry"] = expiry + user_data["customer_credit_card_cvv"] = cvv + return f"The credit card number is updated to {number}" + + @llm.ai_callable() + async def confirm_checkout(self) -> str: + """Called when the user confirms the checkout. + Double check the information with the user before calling the function.""" + agent = AgentCallContext.get_current().agent + user_data: UserData = agent.user_data + if not user_data.get("expense"): + return "Please confirm the expense first." + + if ( + not user_data.get("customer_credit_card") + or not user_data.get("customer_credit_card_expiry") + or not user_data.get("customer_credit_card_cvv") + ): + return "Please provide the credit card information first." + + user_data["checked_out"] = True + next_task = AgentTask.get_task("greeter") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to greeter." + + @llm.ai_callable() + async def to_takeaway(self) -> tuple[AgentTask, str]: + """Called when the user wants to update their order.""" + agent = AgentCallContext.get_current().agent + next_task = AgentTask.get_task("takeaway") + update_context(next_task, agent.chat_ctx, agent.user_data) + return next_task, "Transferred to takeaway." 
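+
+# A new task follows the same shape as the classes above: subclass AgentTask,
+# register it by name, and return a (next_task, message) tuple from an
+# ai_callable to hand the conversation over. Sketch (hypothetical "billing"
+# task, assuming the same API):
+#
+#   class Billing(AgentTask):
+#       def __init__(self):
+#           super().__init__(
+#               instructions="You answer billing questions for the restaurant.",
+#               functions=[to_greeter],
+#           )
+#
+#   AgentTask.register_task(Billing(), "billing")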
+ + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + # create tasks + menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" + AgentTask.register_task(Greeter(menu), "greeter") + AgentTask.register_task(Reservation(), "reservation") + AgentTask.register_task(Takeaway(menu), "takeaway") + AgentTask.register_task(Checkout(menu), "checkout") + + # Set up chat logger + chat_log_file = "restaurant_agent.log" + chat_logger = logging.getLogger("chat_logger") + chat_logger.setLevel(logging.INFO) + handler = logging.FileHandler(chat_log_file) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + chat_logger.addHandler(handler) + + async def _before_tts_cb( + agent: VoicePipelineAgent, text: str | AsyncIterable[str] + ) -> str | AsyncIterable[str]: + if isinstance(text, str): + yield text.replace("*", "") + else: + async for t in text: + yield t.replace("*", "") + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(model="gpt-4o-mini"), + tts=cartesia.TTS(), + initial_task=AgentTask.get_task("greeter"), + max_nested_fnc_calls=3, # may call functions in the transition function + before_tts_cb=_before_tts_cb, + ) + + # read text input from the room for easy testing + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.debug("Text input received", extra={"message": data["message"]}) + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + # write the chat log to a file + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + chat_logger.info(f"{message.role}: {message.content}") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + chat_logger.info(f"fnc_calls_collected: {fnc_infos}") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + chat_logger.info(f"fnc_calls_finished: {called_infos}") + + # Start the assistant. This will automatically publish a microphone track and listen to the participant. + agent.start(ctx.room, participant) + await agent.say("Welcome to our restaurant! 
How may I assist you today?") + + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/examples/voice-pipeline-agent/multi_task/restaurant_inline.py b/examples/voice-pipeline-agent/multi_task/restaurant_inline.py new file mode 100644 index 000000000..bc5bb6b20 --- /dev/null +++ b/examples/voice-pipeline-agent/multi_task/restaurant_inline.py @@ -0,0 +1,394 @@ +import json +import logging +from typing import Annotated, Optional, TypedDict, TypeVar + +from dotenv import load_dotenv +from livekit import rtc +from livekit.agents import ( + AutoSubscribe, + JobContext, + JobProcess, + WorkerOptions, + cli, + llm, +) +from livekit.agents.pipeline import AgentCallContext, VoicePipelineAgent +from livekit.agents.pipeline.agent_task import ( + AgentInlineTask, + AgentTask, + ResultNotSet, + TaskFailed, +) +from livekit.agents.stt import SpeechData, SpeechEvent, SpeechEventType +from livekit.plugins import cartesia, deepgram, openai, silero + +load_dotenv() + +logger = logging.getLogger("multi-task-agent") +logger.setLevel(logging.INFO) + + +class UserData(TypedDict): + customer_name: Optional[str] + customer_phone: Optional[str] + + reservation_time: Optional[str] + + order: Optional[list[str]] + customer_credit_card: Optional[str] + customer_credit_card_expiry: Optional[str] + customer_credit_card_cvv: Optional[str] + expense: Optional[float] + checked_out: Optional[bool] + + +T = TypeVar("T", bound=AgentTask) + + +def update_context( + task: T, chat_ctx: llm.ChatContext, keep_tool_calls: bool = False +) -> T: + last_chat_ctx = chat_ctx.truncate(keep_last_n=3, keep_tool_calls=keep_tool_calls) + task.inject_chat_ctx(last_chat_ctx) + return task + + +def update_instructions(instructions: str, user_data: UserData | None = None) -> str: + if user_data: + instructions += f"\nCurrently collected user data: {user_data}." + return instructions + + +class GetName(AgentInlineTask): + def __init__(self, user_data: UserData | None = None): + instructions = "Your job is to ask for and collect the user's name. Please verify the spelling before proceeding." + user_data = user_data or {} + super().__init__( + instructions=update_instructions(instructions, user_data), + preset_result=user_data.get("customer_name"), + ) + + @llm.ai_callable() + async def set_name( + self, name: Annotated[str, llm.TypeInfo(description="The user's name")] + ) -> str: + """Called when the user provides their name.""" + self._result = name + return f"The name is updated to {name}" + + +class GetPhoneNumber(AgentInlineTask): + def __init__(self, user_data: UserData | None = None): + instructions = "Your job is to collect the user's phone number. Please verify the spelling before proceeding." + user_data = user_data or {} + super().__init__( + instructions=update_instructions(instructions, user_data), + preset_result=user_data.get("customer_phone"), + ) + + @llm.ai_callable() + async def set_phone_number( + self, + phone_number: Annotated[ + str, llm.TypeInfo(description="The user's phone number") + ], + ) -> str: + """Called when the user provides their phone number.""" + # validate the phone number + phone_number = phone_number.strip().replace("-", "") + if not phone_number.isdigit() or len(phone_number) != 10: + return ( + "The phone number is not valid. Please provide a 10-digit phone number." 
+ ) + + self._result = phone_number + return f"The phone number is updated to {phone_number}" + + +class GetReservationTime(AgentInlineTask): + def __init__(self, user_data: UserData | None = None): + instructions = "Your job is to ask for the desired reservation time and confirm the timing with the customer." + user_data = user_data or {} + super().__init__( + instructions=update_instructions(instructions, user_data), + preset_result=user_data.get("reservation_time"), + ) + + @llm.ai_callable() + async def set_reservation_time( + self, time: Annotated[str, llm.TypeInfo(description="The reservation time")] + ) -> str: + """Called when the user provides their reservation time.""" + self._result = time + return f"The reservation time is updated to {time}" + + +class TakeOrder(AgentInlineTask): + def __init__(self, menu: str, user_data: UserData | None = None): + instructions = ( + "Your job is to take the customer's order, clarify any special requests, " + f"and confirm the complete order before proceeding. Our menu is {menu}" + ) + user_data = user_data or {} + super().__init__( + instructions=update_instructions(instructions, user_data), + preset_result=user_data.get("order"), + ) + + @llm.ai_callable() + async def update_order( + self, + items: Annotated[ + list[str], llm.TypeInfo(description="The items of the full order") + ], + ) -> str: + """Called when the user updates their order.""" + self._result = items + if not items: + return "All items are removed from the order." + + return f"Updated order to {items}" + + +class GetCreditCard(AgentInlineTask): + def __init__(self, user_data: UserData | None = None): + instructions = "Your job is to collect the customer's payment information: card number, expiration date (MM/YY), and CVV." + user_data = user_data or {} + credit_card = { + "customer_credit_card": user_data.get("customer_credit_card"), + "customer_credit_card_expiry": user_data.get("customer_credit_card_expiry"), + "customer_credit_card_cvv": user_data.get("customer_credit_card_cvv"), + } + super().__init__( + instructions=update_instructions(instructions, user_data), + preset_result=credit_card, + ) + + @llm.ai_callable() + async def set_credit_card( + self, + number: Annotated[str, llm.TypeInfo(description="The credit card number")], + expiry: Annotated[ + str, + llm.TypeInfo( + description="The expiry date of the credit card, in MM/YY format" + ), + ], + cvv: Annotated[str, llm.TypeInfo(description="The CVV of the credit card")], + ) -> str: + """Called when the user provides their credit card information.""" + + # validate the credit card information + if not cvv.isdigit() or len(cvv) != 3: + return "The CVV is not valid. Please provide a 3-digit CVV." + + # validate the expiry date + month, year = expiry.split("/") + if ( + not month.isdigit() + or not year.isdigit() + or len(month) != 2 + or len(year) != 2 + ): + return "The expiry date is not valid." + + self._result = { + "customer_credit_card": number, + "customer_credit_card_expiry": expiry, + "customer_credit_card_cvv": cvv, + } + return f"The credit card information is updated to {self._result}" + + +class HostBot(AgentTask): + def __init__(self, menu: str): + super().__init__( + instructions=( + f"You are a friendly restaurant host. Our menu: {menu}\n" + "Welcome customers and guide them to either make, update or cancel a reservation, " + "or order takeaway and then checkout based on their preference." 
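+                # HostBot acts as the hub task; the AgentInlineTask subclasses
+                # above are launched on demand from its ai_callable methods.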
+            )
+        )
+        self.menu = menu
+
+    @llm.ai_callable()
+    async def make_reservation(self) -> str:
+        """Called when the user wants to make or update a reservation."""
+        agent = AgentCallContext.get_current().agent
+        user_data: UserData = agent.user_data
+
+        try:
+            reservation_time = await update_context(
+                GetReservationTime(user_data), agent.chat_ctx
+            ).run()
+            user_data["reservation_time"] = reservation_time
+
+            name = await update_context(GetName(user_data), agent.chat_ctx).run()
+            user_data["customer_name"] = name
+
+            phone = await update_context(
+                GetPhoneNumber(user_data), agent.chat_ctx
+            ).run()
+            user_data["customer_phone"] = phone
+
+        except TaskFailed as e:
+            return f"Task failed: {e}"
+        except ResultNotSet:
+            return f"Failed to collect user data, the collected data is {user_data}"
+
+        return f"Reservation successful. Updated user data: {user_data}"
+
+    @llm.ai_callable()
+    async def cancel_reservation(self) -> str:
+        """Called when the user wants to cancel the reservation."""
+        agent = AgentCallContext.get_current().agent
+        user_data: UserData = agent.user_data
+        if not user_data.get("reservation_time"):
+            return "You have not made a reservation yet."
+
+        user_data["reservation_time"] = None
+        return f"Reservation cancelled. Updated user data: {user_data}"
+
+    @llm.ai_callable()
+    async def order_takeaway(self) -> str:
+        """Called when the user wants to order takeaway."""
+        agent = AgentCallContext.get_current().agent
+        user_data: UserData = agent.user_data
+
+        try:
+            order = await update_context(
+                TakeOrder(self.menu, user_data), agent.chat_ctx
+            ).run()
+            user_data["order"] = order
+        except TaskFailed as e:
+            return f"Task failed: {e}"
+        except ResultNotSet:
+            return f"Failed to collect the order, the collected data is {user_data}"
+        return f"Order successful. Updated user data: {user_data}"
+
+    @llm.ai_callable()
+    async def checkout(
+        self,
+        expense: Annotated[float, llm.TypeInfo(description="The expense of the order")],
+    ) -> str:
+        """Called when the user confirms the expense of the order and wants to check out."""
+        agent = AgentCallContext.get_current().agent
+        user_data: UserData = agent.user_data
+        user_data["expense"] = expense
+
+        try:
+            name = await update_context(GetName(user_data), agent.chat_ctx).run()
+            user_data["customer_name"] = name
+
+            phone = await update_context(
+                GetPhoneNumber(user_data), agent.chat_ctx
+            ).run()
+            user_data["customer_phone"] = phone
+
+            credit_card = await update_context(
+                GetCreditCard(user_data), agent.chat_ctx
+            ).run()
+            if not isinstance(credit_card, dict) or not all(
+                key in credit_card
+                for key in [
+                    "customer_credit_card",
+                    "customer_credit_card_expiry",
+                    "customer_credit_card_cvv",
+                ]
+            ):
+                return "The credit card information is not valid."
+
+            user_data.update(credit_card)
+        except TaskFailed as e:
+            return f"Task failed: {e}"
+        except ResultNotSet:
+            return f"Failed to collect user data, the collected data is {user_data}"
+
+        user_data["checked_out"] = True
+
+        return f"Updated user data: {user_data}. User checked out."
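+
+# Unlike restaurant_agent.py, HostBot never switches the active task for data
+# collection; it awaits AgentInlineTask.run(), which blocks inside the function
+# call until on_success/on_error resolves the sub-task. Control-flow sketch
+# (using the classes defined in this file):
+#
+#   try:
+#       name = await GetName(user_data).run()  # runs a sub-conversation
+#   except TaskFailed:
+#       pass  # the user refused or gave up; keep the previous data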
+ + +@llm.ai_callable() +async def to_host() -> tuple[AgentTask, str]: + """Called when user asks unrelated questions or requests other services.""" + agent = AgentCallContext.get_current().agent + next_task = AgentTask.get_task(HostBot) + return update_context(next_task, agent.chat_ctx), f"User data: {agent.user_data}" + + +async def entrypoint(ctx: JobContext): + await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY) + + # create tasks + menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2" + greeter = AgentTask.register_task(HostBot(menu)) + + # Set up chat logger + chat_log_file = "restaurant_agent.log" + chat_logger = logging.getLogger("chat_logger") + chat_logger.setLevel(logging.INFO) + handler = logging.FileHandler(chat_log_file) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + chat_logger.addHandler(handler) + + participant = await ctx.wait_for_participant() + agent = VoicePipelineAgent( + vad=ctx.proc.userdata["vad"], + stt=deepgram.STT(), + llm=openai.LLM(), + tts=cartesia.TTS(), + initial_task=greeter, + max_nested_fnc_calls=3, # may call functions in the transition function + ) + + # read text input from the room for easy testing + @ctx.room.on("data_received") + def on_data_received(packet: rtc.DataPacket): + if packet.topic == "lk-chat-topic": + data = json.loads(packet.data.decode("utf-8")) + logger.debug("Text input received", extra={"message": data["message"]}) + + agent._human_input.emit( + "final_transcript", + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + alternatives=[SpeechData(language="en", text=data["message"])], + ), + ) + + # write the chat log to a file + @agent.on("user_speech_committed") + @agent.on("agent_speech_interrupted") + @agent.on("agent_speech_committed") + def on_speech_committed(message: llm.ChatMessage): + chat_logger.info(f"{message.role}: {message.content}") + + @agent.on("function_calls_collected") + def on_function_calls_collected(calls: list[llm.FunctionCallInfo]): + fnc_infos = [{fnc.function_info.name: fnc.arguments} for fnc in calls] + chat_logger.info(f"fnc_calls_collected: {fnc_infos}") + + @agent.on("function_calls_finished") + def on_function_calls_finished(calls: list[llm.CalledFunction]): + called_infos = [{fnc.call_info.function_info.name: fnc.result} for fnc in calls] + chat_logger.info(f"fnc_calls_finished: {called_infos}") + + # Start the assistant. This will automatically publish a microphone track and listen to the participant. + agent.start(ctx.room, participant) + await agent.say("Welcome to our restaurant! 
How may I assist you today?") + + +def prewarm_process(proc: JobProcess): + # preload silero VAD in memory to speed up session start + proc.userdata["vad"] = silero.VAD.load() + + +if __name__ == "__main__": + cli.run_app( + WorkerOptions( + entrypoint_fnc=entrypoint, + prewarm_fnc=prewarm_process, + ), + ) diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index ccde86bba..79f7cc57a 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -113,7 +113,7 @@ def create_tool_from_called_function( tool_exception: Exception | None = None try: - content = called_function.task.result() + content = called_function.get_content() except BaseException as e: if isinstance(e, Exception): tool_exception = e @@ -176,6 +176,12 @@ def copy(self): copied_msg._metadata = self._metadata return copied_msg + @property + def is_tool_call(self) -> bool: + return self.role == "tool" or ( + self.role == "assistant" and bool(self.tool_calls) + ) + @dataclass class ChatContext: @@ -192,3 +198,31 @@ def copy(self) -> ChatContext: copied_chat_ctx = ChatContext(messages=[m.copy() for m in self.messages]) copied_chat_ctx._metadata = self._metadata return copied_chat_ctx + + def truncate( + self, + keep_last_n: int, + *, + keep_system_message: bool = False, + keep_tool_calls: bool = True, + ) -> ChatContext: + def _keep_message(msg: ChatMessage) -> bool: + if not keep_tool_calls and msg.is_tool_call: + return False + if not keep_system_message and msg.role == "system": + return False + return True + + messages = [msg for msg in self.messages if _keep_message(msg)] + + start = 0 if keep_last_n <= 0 else len(messages) - keep_last_n + copied_messages = [msg.copy() for msg in messages[start:]] + + if keep_tool_calls: + # tool message at the first position is invalid + while copied_messages and copied_messages[0].role == "tool": + copied_messages.pop(0) + + new_ctx = ChatContext(messages=copied_messages) + new_ctx._metadata = self._metadata + return new_ctx diff --git a/livekit-agents/livekit/agents/llm/function_context.py b/livekit-agents/livekit/agents/llm/function_context.py index 59604fc8d..676878004 100644 --- a/livekit-agents/livekit/agents/llm/function_context.py +++ b/livekit-agents/livekit/agents/llm/function_context.py @@ -22,15 +22,22 @@ import types import typing from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple from ..log import logger +if TYPE_CHECKING: + from ..pipeline.agent_task import AgentTask + class _UseDocMarker: pass +class _NoMetadataError(Exception): + pass + + METADATA_ATTR = "__livekit_ai_metadata__" USE_DOCSTRING = _UseDocMarker() @@ -101,6 +108,32 @@ class CalledFunction: result: Any | None = None exception: BaseException | None = None + def get_agent_task(self) -> "AgentTask" | None: + from ..pipeline.agent_task import AgentTask + + assert self.task.done() + + if isinstance(self.result, tuple): + assert len(self.result) == 2 and isinstance(self.result[0], AgentTask) + return self.result[0] + elif isinstance(self.result, AgentTask): + return self.result + return None + + def get_content(self) -> Any | None: + from ..pipeline.agent_task import AgentTask + + assert self.task.done() + + if self.exception: + return f"Error: {self.exception}" + if isinstance(self.result, tuple): + assert len(self.result) == 2 and isinstance(self.result[1], str) + return self.result[1] + elif 
not isinstance(self.result, AgentTask): + return self.result + return None + def ai_callable( *, @@ -136,16 +169,13 @@ def deco(f): return deco - def _register_ai_function(self, fnc: Callable) -> None: + @staticmethod + def _callable_to_fnc_info(fnc: Callable) -> FunctionInfo: if not hasattr(fnc, METADATA_ATTR): - logger.warning(f"function {fnc.__name__} does not have ai metadata") - return + raise _NoMetadataError("function must be decorated with ai_callable") metadata: _AIFncMetadata = getattr(fnc, METADATA_ATTR) fnc_name = metadata.name - if fnc_name in self._fncs: - raise ValueError(f"duplicate ai_callable name: {fnc_name}") - sig = inspect.signature(fnc) # get_type_hints with include_extra=True is needed when using Annotated @@ -190,7 +220,7 @@ def _register_ai_function(self, fnc: Callable) -> None: choices=choices, ) - self._fncs[metadata.name] = FunctionInfo( + return FunctionInfo( name=metadata.name, description=metadata.description, auto_retry=metadata.auto_retry, @@ -198,10 +228,27 @@ def _register_ai_function(self, fnc: Callable) -> None: arguments=args, ) + def _register_ai_function(self, fnc: Callable) -> None: + try: + fnc_info = self._callable_to_fnc_info(fnc) + except _NoMetadataError: + logger.warning(f"function {fnc.__name__} does not have ai metadata") + return + + if fnc_info.name in self._fncs: + raise ValueError(f"duplicate ai_callable name: {fnc_info.name}") + + self._fncs[fnc_info.name] = fnc_info + @property def ai_functions(self) -> dict[str, FunctionInfo]: return self._fncs + def copy(self) -> "FunctionContext": + new_fnc_ctx = FunctionContext() + new_fnc_ctx._fncs.update(self._fncs) + return new_fnc_ctx + @dataclass(frozen=True) class _AIFncMetadata: diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py index 5599b2e93..cdfa3d992 100644 --- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py +++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py @@ -19,6 +19,7 @@ from livekit.agents import llm, stt, tokenize, transcription, utils, vad from livekit.agents.llm import ChatMessage from livekit.agents.metrics import MultimodalLLMMetrics +from livekit.agents.pipeline.agent_task import AgentTask from ..log import logger from ..types import ATTRIBUTE_AGENT_STATE, AgentState @@ -72,6 +73,7 @@ def session( *, chat_ctx: llm.ChatContext | None = None, fnc_ctx: llm.FunctionContext | None = None, + init_task: AgentTask | None = None, ) -> _RealtimeAPISession: """ Create a new realtime session with the given chat and function contexts. 
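With this change, `MultimodalAgent` is constructed with an `initial_task` rather than a `fnc_ctx`. A minimal construction sketch, mirroring the openai_multitask.py example above (the `RealtimeModel` arguments are illustrative):

```python
from livekit.agents import multimodal
from livekit.agents.pipeline import AgentTask
from livekit.plugins import openai

greeter = AgentTask(instructions="You are a friendly greeter.")

agent = multimodal.MultimodalAgent(
    model=openai.realtime.RealtimeModel(voice="alloy"),
    initial_task=greeter,  # replaces the old fnc_ctx argument
)
```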
@@ -143,7 +145,8 @@ def __init__( model: _RealtimeAPI, vad: vad.VAD | None = None, chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, + # fnc_ctx: llm.FunctionContext | None = None, + initial_task: AgentTask | None = None, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), max_text_response_retries: int = 5, loop: asyncio.AbstractEventLoop | None = None, @@ -170,7 +173,8 @@ def __init__( self._model = model self._vad = vad self._chat_ctx = chat_ctx - self._fnc_ctx = fnc_ctx + self._initial_task = initial_task or AgentTask() + # self._fnc_ctx = fnc_ctx self._opts = _ImplOptions( transcription=transcription, @@ -235,7 +239,7 @@ def start( break self._session = self._model.session( - chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx + chat_ctx=self._chat_ctx, init_task=self._initial_task ) # Create a task to wait for initialization and start the main task diff --git a/livekit-agents/livekit/agents/pipeline/__init__.py b/livekit-agents/livekit/agents/pipeline/__init__.py index 480dd7990..51b41593a 100644 --- a/livekit-agents/livekit/agents/pipeline/__init__.py +++ b/livekit-agents/livekit/agents/pipeline/__init__.py @@ -1,3 +1,4 @@ +from .agent_task import AgentTask from .pipeline_agent import ( AgentCallContext, AgentTranscriptionOptions, @@ -8,4 +9,5 @@ "VoicePipelineAgent", "AgentCallContext", "AgentTranscriptionOptions", + "AgentTask", ] diff --git a/livekit-agents/livekit/agents/pipeline/agent_output.py b/livekit-agents/livekit/agents/pipeline/agent_output.py index 14a836ef7..2deb80466 100644 --- a/livekit-agents/livekit/agents/pipeline/agent_output.py +++ b/livekit-agents/livekit/agents/pipeline/agent_output.py @@ -6,7 +6,7 @@ from livekit import rtc -from .. import llm, tokenize, utils +from .. import tokenize, utils from .. import transcription as agent_transcription from .. 
import tts as text_to_speech from .agent_playout import AgentPlayout, PlayoutHandle @@ -96,15 +96,9 @@ def __init__( *, room: rtc.Room, agent_playout: AgentPlayout, - llm: llm.LLM, tts: text_to_speech.TTS, ) -> None: - self._room, self._agent_playout, self._llm, self._tts = ( - room, - agent_playout, - llm, - tts, - ) + self._room, self._agent_playout, self._tts = room, agent_playout, tts self._tasks = set[asyncio.Task[Any]]() @property diff --git a/livekit-agents/livekit/agents/pipeline/agent_task.py b/livekit-agents/livekit/agents/pipeline/agent_task.py new file mode 100644 index 000000000..ffec81a3c --- /dev/null +++ b/livekit-agents/livekit/agents/pipeline/agent_task.py @@ -0,0 +1,243 @@ +import asyncio +import inspect +import logging +from typing import Annotated, Any, Callable, Dict, Optional, Type, Union + +from ..llm import LLM, ChatContext, FunctionContext +from ..llm.function_context import METADATA_ATTR, TypeInfo, ai_callable +from ..stt import STT +from .speech_handle import SpeechHandle + +logger = logging.getLogger(__name__) + + +class ResultNotSet(Exception): + """Exception raised when the task result is not set.""" + + +class TaskFailed(Exception): + """Exception raised when the task fails.""" + + +class SilentSentinel: + """Sentinel value to indicate the function call shouldn't create a response.""" + + def __init__(self, result: Any = None, error: Optional[BaseException] = None): + self._result = result + self._error = error + + def __repr__(self) -> str: + return f"SilentSentinel(result={self._result}, error={self._error})" + + +class AgentTask: + # Single class-level storage for all tasks + _registered_tasks: Dict[Union[str, Type["AgentTask"]], "AgentTask"] = {} + + def __init__( + self, + instructions: Optional[str] = None, + functions: Optional[list[Callable]] = None, + llm: Optional[LLM] = None, + name: Optional[str] = None, + ) -> None: + self._chat_ctx = ChatContext() + self._instructions = instructions + if instructions: + self._chat_ctx.append(text=instructions, role="system") + + self._fnc_ctx = FunctionContext() + functions = functions or [] + # register ai functions from the list + for fnc in functions: + if not hasattr(fnc, METADATA_ATTR): + fnc = ai_callable()(fnc) + self._fnc_ctx._register_ai_function(fnc) + + # register ai functions from the class + for _, member in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(member, METADATA_ATTR) and member not in functions: + self._fnc_ctx._register_ai_function(member) + + self._llm = llm + self._stt = None + + # Auto-register if name is provided + if name is not None: + self.register_task(self, name) + self._name = name + + @classmethod + def register_task( + cls, task: "AgentTask", name: Optional[str] = None + ) -> "AgentTask": + """Register a task instance globally""" + # Register by name if provided + if name is not None: + if name in cls._registered_tasks: + raise ValueError(f"Task with name '{name}' already registered") + cls._registered_tasks[name] = task + else: + # register by type + task_type = type(task) + if task_type in cls._registered_tasks: + raise ValueError( + f"Task of type {task_type.__name__} already registered" + ) + cls._registered_tasks[task_type] = task + + return task + + def inject_chat_ctx(self, chat_ctx: ChatContext) -> None: + # filter duplicate messages + existing_messages = {msg.id: msg for msg in self._chat_ctx.messages} + for msg in chat_ctx.messages: + if msg.id not in existing_messages: + self._chat_ctx.messages.append(msg) + + @property + def instructions(self) 
-> Optional[str]: + return self._instructions + + @property + def chat_ctx(self) -> ChatContext: + return self._chat_ctx + + @chat_ctx.setter + def chat_ctx(self, chat_ctx: ChatContext) -> None: + self._chat_ctx = chat_ctx + + @property + def fnc_ctx(self) -> FunctionContext: + return self._fnc_ctx + + @property + def llm(self) -> Optional[LLM]: + return self._llm + + @property + def stt(self) -> Optional[STT]: + return self._stt + + @classmethod + def get_task(cls, key: Union[str, Type["AgentTask"]]) -> "AgentTask": + """Get task instance by name or class""" + if key not in cls._registered_tasks: + raise ValueError(f"Task with name or class {key} not found") + return cls._registered_tasks[key] + + @classmethod + def all_registered_tasks(cls) -> list["AgentTask"]: + """Get all registered tasks""" + return list(set(cls._registered_tasks.values())) + + def __repr__(self) -> str: + if self._name: + return f"{self.__class__.__name__}(name={self._name})" + return f"{self.__class__.__name__}()" + + +class AgentInlineTask(AgentTask): + def __init__( + self, + instructions: Optional[str] = None, + functions: Optional[list[Callable]] = None, + llm: Optional[LLM] = None, + name: Optional[str] = None, + preset_result: Optional[Any] = None, + ) -> None: + super().__init__(instructions, functions, llm, name) + + self._done_fut: asyncio.Future[None] = asyncio.Future() + self._result: Optional[Any] = preset_result + + self._parent_task: Optional[AgentTask] = None + self._parent_speech: Optional[SpeechHandle] = None + + async def run(self, proactive_reply: bool = True) -> Any: + from ..pipeline.pipeline_agent import AgentCallContext + + call_ctx = AgentCallContext.get_current() + agent = call_ctx.agent + + self._parent_task = agent.current_agent_task + self._parent_speech = call_ctx.speech_handle + agent.update_task(self) + logger.debug( + "running inline task", + extra={"task": str(self), "parent_task": str(self._parent_task)}, + ) + try: + # generate reply to the user + if proactive_reply: + speech_handle = SpeechHandle.create_assistant_speech( + allow_interruptions=agent._opts.allow_interruptions, + add_to_chat_ctx=True, + ) + self._proactive_reply_task = asyncio.create_task( + agent._synthesize_answer_task(None, speech_handle) + ) + if self._parent_speech is not None: + self._parent_speech.add_nested_speech(speech_handle) + else: + agent._add_speech_for_playout(speech_handle) + + # wait for the task to complete + await self._done_fut + if self.exception: + raise self.exception + + if self._result is None: + raise ResultNotSet() + return self._result + finally: + # reset the parent task + agent.update_task(self._parent_task) + logger.debug( + "inline task completed", + extra={ + "result": self._result, + "error": self.exception, + "task": str(self), + "parent_task": str(self._parent_task), + }, + ) + + @ai_callable() + def on_success(self) -> SilentSentinel: + """Called when user confirms the information is correct. + This function is called to indicate the job is done. + """ + if not self._done_fut.done(): + self._done_fut.set_result(None) + return SilentSentinel(result=self.result, error=self.exception) + + @ai_callable() + def on_error( + self, + reason: Annotated[str, TypeInfo(description="The reason for the error")], + ) -> SilentSentinel: + """Called when user wants to stop or refuses to provide the information. + Only focus on your job. 
+ """ + if not self._done_fut.done(): + self._done_fut.set_exception(TaskFailed(reason)) + return SilentSentinel(result=self.result, error=self.exception) + + @property + def done(self) -> bool: + return self._done_fut.done() + + @property + def result(self) -> Any: + return self._result + + @property + def exception(self) -> Optional[BaseException]: + return self._done_fut.exception() + + def __repr__(self) -> str: + speech_id = self._parent_speech.id if self._parent_speech else None + return ( + f"{self.__class__.__name__}(parent_speech={speech_id}, name={self._name})" + ) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index e67af5b9f..4931e9fd3 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -19,10 +19,18 @@ from livekit import rtc from .. import metrics, stt, tokenize, tts, utils, vad -from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream +from ..llm import ( + LLM, + CalledFunction, + ChatContext, + ChatMessage, + FunctionContext, + LLMStream, +) from ..types import ATTRIBUTE_AGENT_STATE, AgentState from .agent_output import AgentOutput, SpeechSource, SynthesisHandle from .agent_playout import AgentPlayout +from .agent_task import AgentInlineTask, AgentTask, SilentSentinel from .human_input import HumanInput from .log import logger from .plotter import AssistantPlotter @@ -65,10 +73,16 @@ class AgentCallContext: - def __init__(self, assistant: "VoicePipelineAgent", llm_stream: LLMStream) -> None: + def __init__( + self, + assistant: "VoicePipelineAgent", + llm_stream: LLMStream, + speech_handle: SpeechHandle, + ) -> None: self._assistant = assistant self._metadata = dict[str, Any]() self._llm_stream = llm_stream + self._speech_handle = speech_handle self._extra_chat_messages: list[ChatMessage] = [] @staticmethod @@ -92,6 +106,10 @@ def get_metadata(self, key: str, default: Any = None) -> Any: def llm_stream(self) -> LLMStream: return self._llm_stream + @property + def speech_handle(self) -> SpeechHandle: + return self._speech_handle + def add_extra_chat_message(self, message: ChatMessage) -> None: """Append chat message to the end of function outputs for the answer LLM call""" self._extra_chat_messages.append(message) @@ -185,8 +203,9 @@ def __init__( llm: LLM, tts: tts.TTS, turn_detector: _TurnDetector | None = None, - chat_ctx: ChatContext | None = None, - fnc_ctx: FunctionContext | None = None, + # chat_ctx: ChatContext | None = None, + # fnc_ctx: FunctionContext | None = None, + initial_task: AgentTask | None = None, allow_interruptions: bool = True, interrupt_speech_duration: float = 0.5, interrupt_min_words: int = 0, @@ -278,8 +297,8 @@ def __init__( self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts self._turn_detector = turn_detector - self._chat_ctx = chat_ctx or ChatContext() - self._fnc_ctx = fnc_ctx + # self._chat_ctx = chat_ctx or ChatContext() + # self._fnc_ctx = fnc_ctx self._started, self._closed = False, False self._human_input: HumanInput | None = None @@ -310,13 +329,29 @@ def __init__( self._last_final_transcript_time: float | None = None self._last_speech_time: float | None = None + # agent tasks + self._user_data: dict[str, Any] = {} + self._current_agent_task = initial_task or AgentTask() + @property - def fnc_ctx(self) -> FunctionContext | None: - return self._fnc_ctx + def user_data(self) -> dict[str, Any]: + return self._user_data - @fnc_ctx.setter - def 
fnc_ctx(self, fnc_ctx: FunctionContext | None) -> None:
-        self._fnc_ctx = fnc_ctx
+    @property
+    def current_agent_task(self) -> AgentTask:
+        return self._current_agent_task
+
+    def update_task(self, task: AgentTask) -> None:
+        self._current_agent_task = task
+
+    @property
+    def fnc_ctx(self) -> FunctionContext:
+        return self._current_agent_task.fnc_ctx
+
+    @property
+    def _chat_ctx(self) -> ChatContext:
+        # kept for backward compatibility with existing self._chat_ctx accesses
+        return self._current_agent_task.chat_ctx

     @property
     def chat_ctx(self) -> ChatContext:
@@ -324,7 +359,7 @@
     @property
     def llm(self) -> LLM:
-        return self._llm
+        return self._current_agent_task.llm or self._llm

     @property
     def tts(self) -> tts.TTS:
@@ -387,6 +422,10 @@ def _on_llm_metrics(llm_metrics: metrics.LLMMetrics) -> None:
                 ),
             )

+        # for agent_task in self._agent_tasks:
+        #     if agent_task.llm:
+        #         agent_task.llm.on("metrics_collected", _on_llm_metrics)
+
         @self._vad.on("metrics_collected")
         def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None:
             self.emit(
@@ -652,10 +691,7 @@ async def _main_task(self) -> None:
         agent_playout = AgentPlayout(audio_source=audio_source)
         self._agent_output = AgentOutput(
-            room=self._room,
-            agent_playout=agent_playout,
-            llm=self._llm,
-            tts=self._tts,
+            room=self._room, agent_playout=agent_playout, tts=self._tts
         )

         def _on_playout_started() -> None:
@@ -879,32 +915,35 @@ async def _execute_function_calls() -> None:
             if not is_using_tools or interrupted:
                 return

+            assert isinstance(speech_handle.source, LLMStream)
+            assert not user_question or speech_handle.user_committed, (
+                "user speech should have been committed before using tools"
+            )
+
+            llm_stream = speech_handle.source
+
             if speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls:
                 logger.warning(
                     "max function calls nested depth reached",
                     extra={
                         "speech_id": speech_handle.id,
                         "fnc_nested_depth": speech_handle.fnc_nested_depth,
+                        "fnc_names": [
+                            fnc.function_info.name for fnc in llm_stream.function_calls
+                        ],
                     },
                 )
                 return

-            assert isinstance(speech_handle.source, LLMStream)
-            assert not user_question or speech_handle.user_committed, (
-                "user speech should have been committed before using tools"
-            )
-
-            llm_stream = speech_handle.source
-
             # execute functions
-            call_ctx = AgentCallContext(self, llm_stream)
+            call_ctx = AgentCallContext(self, llm_stream, speech_handle)
             tk = _CallContextVar.set(call_ctx)

             new_function_calls = llm_stream.function_calls

             self.emit("function_calls_collected", new_function_calls)

-            called_fncs = []
+            called_fncs: list[CalledFunction] = []
             for fnc in new_function_calls:
                 called_fnc = fnc.execute()
                 called_fncs.append(called_fnc)
@@ -928,19 +967,40 @@ async def _execute_function_calls() -> None:
             )

             tool_calls_info = []
-            tool_calls_results = []
-
+            tool_calls_results: list[ChatMessage] = []
+            tool_calls_chat_ctx = call_ctx.chat_ctx
+            should_create_response = True
+            original_task = self.current_agent_task
             for called_fnc in called_fncs:
                 # ignore the function calls that returns None
                 if called_fnc.result is None and called_fnc.exception is None:
                     continue
+                if isinstance(called_fnc.result, SilentSentinel):
+                    should_create_response = False
+                    continue
+
+                new_task = called_fnc.get_agent_task()
+                if new_task:
+                    logger.debug(
+                        "switching to next agent task",
+                        extra={
+                            "new_task": str(new_task),
+                            "previous_task": str(self.current_agent_task),
+                        },
+                    )
+                    self.update_task(new_task)
+                    # TODO: should we update the task after the function call is done?
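For reviewers, a minimal sketch (not part of this patch) of how the switching and silencing paths above are driven from user code. It assumes AgentTask is exported from livekit.agents.pipeline (as the realtime plugin change below imports it), that SilentSentinel is importable from livekit.agents.pipeline.agent_task, and that called_fnc.get_agent_task() recognizes an AgentTask in a function's return value; the task names and the to_billing/log_silently functions are illustrative.

    from livekit.agents import llm
    from livekit.agents.pipeline import AgentTask
    from livekit.agents.pipeline.agent_task import SilentSentinel

    # registering with a name makes the task resolvable via AgentTask.get_task()
    AgentTask(instructions="You answer support questions.", name="support")
    AgentTask(instructions="You answer billing questions.", name="billing")


    @llm.ai_callable()
    async def to_billing() -> tuple[AgentTask, str]:
        """Called when the user asks a billing question."""
        # returning an AgentTask makes called_fnc.get_agent_task() non-None,
        # so the loop above switches the current task before the follow-up
        # response is synthesized
        return AgentTask.get_task("billing"), "Transferring you to billing."


    @llm.ai_callable()
    async def log_silently() -> SilentSentinel:
        """Called when the interaction should be recorded."""
        # a SilentSentinel result flips should_create_response to False above,
        # so the call executes without triggering a spoken tool response
        return SilentSentinel(result="logged")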
+ # use the new chat ctx for the next task + tool_calls_chat_ctx = self._chat_ctx + tool_calls_info.append(called_fnc.call_info) tool_calls_results.append( ChatMessage.create_tool_from_called_function(called_fnc) ) - if not tool_calls_info: + if not tool_calls_info or not should_create_response: + self.emit("function_calls_finished", called_fncs) return # create a nested speech handle @@ -949,6 +1009,11 @@ async def _execute_function_calls() -> None: ] extra_tools_messages.extend(tool_calls_results) + if original_task != self.current_agent_task: + # add the function call results to the original task + original_task.chat_ctx.messages.extend(extra_tools_messages) + extra_tools_messages = [] + new_speech_handle = SpeechHandle.create_tool_speech( allow_interruptions=speech_handle.allow_interruptions, add_to_chat_ctx=speech_handle.add_to_chat_ctx, @@ -958,28 +1023,27 @@ async def _execute_function_calls() -> None: ) # synthesize the tool speech with the chat ctx from llm_stream - chat_ctx = call_ctx.chat_ctx.copy() + chat_ctx = tool_calls_chat_ctx.copy() chat_ctx.messages.extend(extra_tools_messages) chat_ctx.messages.extend(call_ctx.extra_chat_messages) - fnc_ctx = self.fnc_ctx + fnc_ctx: Optional[FunctionContext] = self.fnc_ctx if ( fnc_ctx + and len(fnc_ctx.ai_functions) > 0 and new_speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls ): if len(fnc_ctx.ai_functions) > 1: logger.info( - "max function calls nested depth reached, dropping function context. increase max_nested_fnc_calls to enable additional nesting.", + "max function calls nested depth reached, dropping function context. " + "increase max_nested_fnc_calls to enable additional nesting.", extra={ "speech_id": speech_handle.id, "fnc_nested_depth": speech_handle.fnc_nested_depth, }, ) fnc_ctx = None - answer_llm_stream = self._llm.chat( - chat_ctx=chat_ctx, - fnc_ctx=fnc_ctx, - ) + answer_llm_stream = self.llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx) synthesis_handle = self._synthesize_agent_speech( new_speech_handle.id, answer_llm_stream @@ -1093,7 +1157,11 @@ async def _llm_stream_to_str_generator( def _validate_reply_if_possible(self) -> None: """Check if the new agent speech should be played""" - if self._playing_speech and not self._playing_speech.interrupted: + if ( + self._playing_speech + and not self._playing_speech.interrupted + and not self._inline_task_running() + ): should_ignore_input = False if not self._playing_speech.allow_interruptions: should_ignore_input = True @@ -1124,14 +1192,15 @@ def _validate_reply_if_possible(self) -> None: assert self._pending_agent_reply is not None - # due to timing, we could end up with two pushed agent replies inside the speech queue. - # so make sure we directly interrupt every reply when validating a new one - for speech in self._speech_q: - if not speech.is_reply: - continue + if not self._inline_task_running(): + # due to timing, we could end up with two pushed agent replies inside the speech queue. 
+ # so make sure we directly interrupt every reply when validating a new one + for speech in self._speech_q: + if not speech.is_reply: + continue - if speech.allow_interruptions: - speech.interrupt() + if speech.allow_interruptions: + speech.interrupt() logger.debug( "validated agent reply", @@ -1155,7 +1224,11 @@ def _validate_reply_if_possible(self) -> None: ) self.emit("metrics_collected", eou_metrics) - self._add_speech_for_playout(self._pending_agent_reply) + if self._playing_speech and not self._playing_speech.nested_speech_done: + self._playing_speech.add_nested_speech(self._pending_agent_reply) + else: + self._add_speech_for_playout(self._pending_agent_reply) + self._pending_agent_reply = None self._transcribed_interim_text = "" # self._transcribed_text is reset after MIN_TIME_PLAYED_FOR_COMMIT, see self._play_speech @@ -1166,7 +1239,7 @@ def _interrupt_if_possible(self) -> None: self._playing_speech.interrupt() def _should_interrupt(self) -> bool: - if self._playing_speech is None: + if self._playing_speech is None or self._inline_task_running(): return False if ( @@ -1183,6 +1256,15 @@ def _should_interrupt(self) -> bool: return True + def _inline_task_running(self) -> bool: + if ( + not isinstance(self.current_agent_task, AgentInlineTask) + or self._playing_speech is None + ): + return False + + return self.current_agent_task._parent_speech is self._playing_speech + def _add_speech_for_playout(self, speech_handle: SpeechHandle) -> None: self._speech_q.append(speech_handle) self._speech_q_changed.set() diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py index fbb453609..f09772e16 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py @@ -3,6 +3,7 @@ DEFAULT_INPUT_AUDIO_TRANSCRIPTION, DEFAULT_SERVER_VAD_OPTIONS, InputTranscriptionOptions, + RealtimeCallContext, RealtimeContent, RealtimeError, RealtimeModel, @@ -30,4 +31,5 @@ "api_proto", "DEFAULT_INPUT_AUDIO_TRANSCRIPTION", "DEFAULT_SERVER_VAD_OPTIONS", + "RealtimeCallContext", ] diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py index 8b6b717f7..ab6685ca9 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py @@ -2,11 +2,12 @@ import asyncio import base64 +import contextvars import os import time from copy import deepcopy from dataclasses import dataclass -from typing import AsyncIterable, Literal, Optional, Union, cast, overload +from typing import Any, AsyncIterable, Literal, Optional, Union, cast, overload from urllib.parse import urlencode import aiohttp @@ -14,6 +15,7 @@ from livekit.agents import llm, utils from livekit.agents.llm.function_context import _create_ai_function_info from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics +from livekit.agents.pipeline import AgentTask from livekit.agents.types import NOT_GIVEN, NotGivenOr from typing_extensions import TypedDict @@ -42,6 +44,28 @@ ] +_CallContextVar = contextvars.ContextVar["RealtimeCallContext"]( + "realtime_session_contextvar" +) + + +class RealtimeCallContext: + def __init__( + self, + session: 
"RealtimeSession", + ) -> None: + self._session = session + self._metadata = dict[str, Any]() + + @staticmethod + def get_current() -> "RealtimeCallContext": + return _CallContextVar.get() + + @property + def session(self) -> "RealtimeSession": + return self._session + + @dataclass class InputTranscriptionCompleted: item_id: str @@ -445,7 +469,8 @@ def session( self, *, chat_ctx: llm.ChatContext | None = None, - fnc_ctx: llm.FunctionContext | None = None, + # fnc_ctx: llm.FunctionContext | None = None, + init_task: AgentTask | None = None, modalities: list[api_proto.Modality] | None = None, instructions: str | None = None, voice: api_proto.Voice | None = None, @@ -483,7 +508,8 @@ def session( new_session = RealtimeSession( chat_ctx=chat_ctx or llm.ChatContext(), - fnc_ctx=fnc_ctx, + # fnc_ctx=fnc_ctx, + init_task=init_task, opts=opts, http_session=self._ensure_session(), loop=self._loop, @@ -809,7 +835,8 @@ def __init__( opts: _ModelOptions, http_session: aiohttp.ClientSession, chat_ctx: llm.ChatContext, - fnc_ctx: llm.FunctionContext | None, + # fnc_ctx: llm.FunctionContext | None, + init_task: AgentTask | None = None, loop: asyncio.AbstractEventLoop, ) -> None: super().__init__() @@ -825,7 +852,10 @@ def __init__( self._item_deleted_futs: dict[str, asyncio.Future[bool]] = {} self._item_truncated_futs: dict[str, asyncio.Future[bool]] = {} - self._fnc_ctx = fnc_ctx + # self._fnc_ctx = fnc_ctx + self._current_task = init_task or AgentTask() + self._user_data: dict[str, Any] = {} + self._loop = loop self._opts = opts @@ -853,11 +883,25 @@ async def aclose(self) -> None: @property def fnc_ctx(self) -> llm.FunctionContext | None: - return self._fnc_ctx + return self._current_task.fnc_ctx + + @property + def user_data(self) -> dict[str, Any]: + return self._user_data + + @property + def current_task(self) -> AgentTask: + return self._current_task + + async def update_task(self, task: AgentTask) -> None: + self._current_task.chat_ctx = self.chat_ctx_copy() + self._current_task = task + self.session_update(instructions=task.instructions) - @fnc_ctx.setter - def fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None: - self._fnc_ctx = fnc_ctx + # # remove the function calls from the chat context + # chat_ctx = self.chat_ctx_copy() + # chat_ctx = chat_ctx.truncate(keep_last_n=-1, keep_tool_calls=False) + # await self.set_chat_ctx(chat_ctx) @property def conversation(self) -> Conversation: @@ -913,13 +957,14 @@ def session_update( self._opts.max_response_output_tokens = max_response_output_tokens tools = [] - if self._fnc_ctx is not None: - for fnc in self._fnc_ctx.ai_functions.values(): + if self.fnc_ctx is not None: + for fnc in self.fnc_ctx.ai_functions.values(): # the realtime API is using internally-tagged polymorphism. 
# build_oai_function_description was built for the ChatCompletion API
                function_data = build_oai_function_description(fnc)["function"]
                function_data["type"] = "function"
                tools.append(function_data)
+        logger.debug(f"tools: {tools}")

         server_vad_opts: api_proto.ServerVad | None = None
         if self._opts.turn_detection is not None:
@@ -1542,7 +1587,7 @@ def _handle_response_output_item_done(
         output = response.output[output_index]

         if output.type == "function_call":
-            if self._fnc_ctx is None:
+            if self.fnc_ctx is None:
                 logger.error(
                     "function call received but no fnc_ctx is available",
                     extra=self.logging_extra(),
@@ -1554,7 +1599,7 @@
             assert item["type"] == "function_call"

             fnc_call_info = _create_ai_function_info(
-                self._fnc_ctx,
+                self.fnc_ctx,
                 item["call_id"],
                 item["name"],
                 item["arguments"],
@@ -1680,17 +1725,44 @@ async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str)
             "executing ai function",
             extra={
                 "function": fnc_call_info.function_info.name,
+                "arguments": fnc_call_info.function_info.arguments,
             },
         )

-        called_fnc = fnc_call_info.execute()
-        await called_fnc.task
+        tk = _CallContextVar.set(RealtimeCallContext(self))
+        try:
+            called_fnc = fnc_call_info.execute()
+            await called_fnc.task
+        finally:
+            # reset even if the ai function raises so the context var does
+            # not leak into subsequent calls
+            _CallContextVar.reset(tk)

         tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
+
+        # switch to the returned task, if any
+        new_task = called_fnc.get_agent_task()
+        if new_task:
+            original_task = self._current_task
+            logger.debug(
+                "switching to next agent task",
+                extra={
+                    "new_task": str(new_task),
+                    "previous_task": str(self.current_task),
+                },
+            )
+            await self.update_task(new_task)
+
+            # update the chat context of the original task
+            original_task.chat_ctx.messages.extend(
+                [
+                    llm.ChatMessage.create_tool_calls([called_fnc.call_info]),
+                    tool_call,
+                ]
+            )
+
         logger.info(
-            "creating response for tool call",
+            "ai functions done, creating response for tool call",
             extra={
                 "function": fnc_call_info.function_info.name,
+                "result": called_fnc.get_content(),
             },
         )
         if tool_call.content is not None:
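To close the loop on AgentInlineTask.run() above, a minimal sketch of an inline task driven from a pipeline-agent function call. CollectAddress and collect_address are illustrative, the import path is assumed, and assigning self._result directly is one plausible way for a subclass to surface a value to run(); the patch itself only prescribes preset_result.

    from typing import Annotated

    from livekit.agents import llm
    from livekit.agents.pipeline.agent_task import AgentInlineTask, TaskFailed


    class CollectAddress(AgentInlineTask):
        def __init__(self) -> None:
            super().__init__(
                instructions="Ask for the delivery address, then confirm it back."
            )

        @llm.ai_callable()
        async def update_address(
            self,
            address: Annotated[str, llm.TypeInfo(description="The delivery address")],
        ) -> str:
            """Called when the user provides their delivery address."""
            # stash the value so run() returns it once the inherited
            # on_success() resolves the task's future
            self._result = address
            return f"The address is updated to {address}"


    @llm.ai_callable()
    async def collect_address() -> str:
        """Called when the user asks for delivery."""
        try:
            # run() swaps this task in as the agent's current task, optionally
            # speaks a proactive prompt, and resolves when on_success()/on_error()
            # completes the future; TaskFailed carries the reason from on_error()
            address = await CollectAddress().run(proactive_reply=True)
        except TaskFailed as e:
            return f"Could not collect a delivery address: {e}"
        return f"Delivery address recorded: {address}"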