diff --git a/agent_stream/schedule_agent.py b/agent_stream/schedule_agent.py index 20fdcb35..e3b8b53a 100644 --- a/agent_stream/schedule_agent.py +++ b/agent_stream/schedule_agent.py @@ -3,6 +3,7 @@ import time from restack_ai import Restack +from restack_ai.event import AgentEvent from src.agents.agent import AgentStream @@ -11,7 +12,12 @@ async def main() -> None: agent_id = f"{int(time.time() * 1000)}-{AgentStream.__name__}" run_id = await client.schedule_agent( - agent_name=AgentStream.__name__, agent_id=agent_id + agent_name=AgentStream.__name__, + agent_id=agent_id, + event=AgentEvent( + name="messages", + input={"messages": [{"role": "user", "content": "Tell me a joke"}]}, + ), ) await client.get_agent_result(agent_id=agent_id, run_id=run_id) diff --git a/agent_video/README.md b/agent_video/pipecat/README.md similarity index 76% rename from agent_video/README.md rename to agent_video/pipecat/README.md index 1b794c95..37bfd3ce 100644 --- a/agent_video/README.md +++ b/agent_video/pipecat/README.md @@ -11,7 +11,7 @@ For a complete documentation on how the agent works and how to setup the service - Python 3.10 or higher - Deepgram account (For speech-to-text transcription) - Cartesia account (for text-to-speech and voice cloning) -- Tavus account (for video replica) +- Tavus or Heygen account (for video replica) ## Start Restack @@ -21,9 +21,17 @@ To start the Restack, use the following Docker command: docker run -d --pull always --name restack -p 5233:5233 -p 6233:6233 -p 7233:7233 -p 9233:9233 ghcr.io/restackio/restack:main ``` -## Start python shell +## Configure environment variables -If using uv: +In all subfolders, duplicate the `env.example` file and rename it to `.env`. 
+ +Obtain a Restack API Key to interact with the 'gpt-4o-mini' model at no cost from [Restack Cloud](https://console.restack.io/starter) + +## Start Restack Agent + +in /agent + +### Start python shell ```bash uv venv && source .venv/bin/activate @@ -35,7 +43,7 @@ If using pip: python -m venv .venv && source .venv/bin/activate ``` -## Install dependencies +### Install dependencies If using uv: @@ -51,13 +59,37 @@ pip install -e . python -c "from src.services import watch_services; watch_services()" ``` -## Configure Your Environment Variables +## Start Pipecat pipeline -Duplicate the `env.example` file and rename it to `.env`. +in /pipeline -Obtain a Restack API Key to interact with the 'gpt-4o-mini' model at no cost from [Restack Cloud](https://console.restack.io/starter) +### Start python shell + +```bash +uv venv && source .venv/bin/activate +``` + +If using pip: -## Create Room and run Agent in parallel +```bash +python -m venv .venv && source .venv/bin/activate +``` + +### Install dependencies + +If using uv: + +```bash +uv sync +uv run dev +``` + +If using pip: + +```bash +pip install -e . +python -c "from src.services import watch_services; watch_services()" +``` ### from UI diff --git a/agent_video/.env.example b/agent_video/pipecat/agent/.env.example similarity index 100% rename from agent_video/.env.example rename to agent_video/pipecat/agent/.env.example diff --git a/agent_video/.python-version b/agent_video/pipecat/agent/.python-version similarity index 100% rename from agent_video/.python-version rename to agent_video/pipecat/agent/.python-version diff --git a/agent_video/pipecat/agent/README.md b/agent_video/pipecat/agent/README.md new file mode 100644 index 00000000..6fd018e9 --- /dev/null +++ b/agent_video/pipecat/agent/README.md @@ -0,0 +1,2 @@ + +See parent README.md at /agent_video/pipecat/README.md for instructions on how to run the agent. 
\ No newline at end of file diff --git a/agent_video/pipecat/agent/pyproject.toml b/agent_video/pipecat/agent/pyproject.toml new file mode 100644 index 00000000..f3b8ef13 --- /dev/null +++ b/agent_video/pipecat/agent/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "agent_video_pipecat_agent" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "openai>=1.59.9", + "pipecat-ai[daily]>=0.0.58", + "python-dotenv>=1.0.1", + "pydantic>=2.10.6", + "watchfiles>=1.0.4", + "restack-ai>=0.0.87",] + +[project.scripts] +dev = "src.services:watch_services" +services = "src.services:run_services" + +[tool.hatch.build.targets.sdist] +include = ["src"] + +[tool.hatch.build.targets.wheel] +include = ["src"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/agent_video/src/__init__.py b/agent_video/pipecat/agent/src/__init__.py similarity index 100% rename from agent_video/src/__init__.py rename to agent_video/pipecat/agent/src/__init__.py diff --git a/agent_video/src/agents/__init__.py b/agent_video/pipecat/agent/src/agents/__init__.py similarity index 100% rename from agent_video/src/agents/__init__.py rename to agent_video/pipecat/agent/src/agents/__init__.py diff --git a/agent_video/pipecat/agent/src/agents/agent.py b/agent_video/pipecat/agent/src/agents/agent.py new file mode 100644 index 00000000..f0fd772c --- /dev/null +++ b/agent_video/pipecat/agent/src/agents/agent.py @@ -0,0 +1,164 @@ +from datetime import timedelta +from typing import Literal + +from pydantic import BaseModel +from restack_ai.agent import ( + NonRetryableError, + RetryPolicy, + agent, + import_functions, + log, + uuid, +) + +from src.workflows.logic import LogicWorkflow, LogicWorkflowInput + +with import_functions(): + from src.functions.context_docs import context_docs + from src.functions.daily_send_data import ( + DailySendDataInput, + daily_send_data, + ) + from 
src.functions.llm_talk import LlmTalkInput, llm_talk, Message, ModelType + + +class MessagesEvent(BaseModel): + messages: list[Message] + + +class EndEvent(BaseModel): + end: bool + +class AgentInput(BaseModel): + room_url: str + model: ModelType + interactive_prompt: str | None = None + reasoning_prompt: str | None = None + + +class ContextEvent(BaseModel): + context: str + + +class DailyMessageEvent(BaseModel): + message: str + recipient: str | None = None + + +@agent.defn() +class AgentVideo: + def __init__(self) -> None: + self.end = False + self.messages: list[Message] = [] + self.room_url = "" + self.model: Literal[ + "restack", "gpt-4o-mini", "gpt-4o", "openpipe:twenty-lions-fall", "ft:gpt-4o-mini-2024-07-18:restack::BJymdMm8" + ] = "restack" + self.interactive_prompt = "" + self.reasoning_prompt = "" + self.context = "" + + @agent.event + async def messages( + self, + messages_event: MessagesEvent, + ) -> list[Message]: + log.info(f"Received message: {messages_event.messages}") + self.messages.extend(messages_event.messages) + try: + await agent.child_start( + workflow=LogicWorkflow, + workflow_id=f"{uuid()}-logic", + workflow_input=LogicWorkflowInput( + messages=self.messages, + room_url=self.room_url, + context=str(self.context), + interactive_prompt=self.interactive_prompt, + reasoning_prompt=self.reasoning_prompt, + model=self.model, + ), + ) + + assistant_message = await agent.step( + function=llm_talk, + function_input=LlmTalkInput( + messages=self.messages[-3:], + context=str(self.context), + mode="default", + model=self.model, + interactive_prompt=self.interactive_prompt, + ), + start_to_close_timeout=timedelta(seconds=3), + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=1), + maximum_attempts=1, + maximum_interval=timedelta(seconds=5), + ), + ) + + except Exception as e: + error_message = f"llm_chat function failed: {e}" + raise NonRetryableError(error_message) from e + else: + self.messages.append( + Message( + role="assistant", + 
content=str(assistant_message), + ), + ) + return self.messages + + @agent.event + async def end(self, end: EndEvent) -> EndEvent: + log.info("Received end") + self.end = True + return end + + @agent.event + async def context(self, context: ContextEvent) -> str: + log.info("Received context") + self.context = context.context + return self.context + + @agent.event + async def daily_message( + self, daily_message: DailyMessageEvent + ) -> bool: + log.info("Received message", daily_message=daily_message) + await agent.step( + function=daily_send_data, + function_input=DailySendDataInput( + room_url=self.room_url, + data={ + "text": daily_message.message, + "author": "agent", + }, + recipient=daily_message.recipient, + ), + ) + return True + + @agent.run + async def run(self, agent_input: AgentInput) -> None: + try: + self.room_url = agent_input.room_url + self.model = agent_input.model + self.interactive_prompt = ( + agent_input.interactive_prompt + ) + self.reasoning_prompt = agent_input.reasoning_prompt + docs = await agent.step(function=context_docs) + except Exception as e: + error_message = f"context_docs function failed: {e}" + raise NonRetryableError(error_message) from e + else: + system_prompt = f""" + You are an AI assistant for Restack. 
You can answer questions about the following documentation: + {docs} + {self.interactive_prompt} + """ + self.messages.append( + Message(role="system", content=system_prompt), + ) + + await agent.condition(lambda: self.end) diff --git a/agent_video/src/client.py b/agent_video/pipecat/agent/src/client.py similarity index 82% rename from agent_video/src/client.py rename to agent_video/pipecat/agent/src/client.py index 885bf8ea..2ff14fbb 100644 --- a/agent_video/src/client.py +++ b/agent_video/pipecat/agent/src/client.py @@ -14,6 +14,9 @@ api_address = os.getenv("RESTACK_ENGINE_API_ADDRESS") connection_options = CloudConnectionOptions( - engine_id=engine_id, address=address, api_key=api_key, api_address=api_address + engine_id=engine_id, + address=address, + api_key=api_key, + api_address=api_address, ) client = Restack(connection_options) diff --git a/agent_video/src/functions/__init__.py b/agent_video/pipecat/agent/src/functions/__init__.py similarity index 100% rename from agent_video/src/functions/__init__.py rename to agent_video/pipecat/agent/src/functions/__init__.py diff --git a/agent_video/src/functions/context_docs.py b/agent_video/pipecat/agent/src/functions/context_docs.py similarity index 64% rename from agent_video/src/functions/context_docs.py rename to agent_video/pipecat/agent/src/functions/context_docs.py index 8bda755b..f53759da 100644 --- a/agent_video/src/functions/context_docs.py +++ b/agent_video/pipecat/agent/src/functions/context_docs.py @@ -7,15 +7,22 @@ async def fetch_content_from_url(url: str) -> str: async with session.get(url) as response: if response.status == 200: return await response.text() - error_message = f"Failed to fetch content: {response.status}" + error_message = ( + f"Failed to fetch content: {response.status}" + ) raise NonRetryableError(error_message) @function.defn() async def context_docs() -> str: try: - docs_content = await fetch_content_from_url("https://docs.restack.io/llms-full.txt") - log.info("Fetched content from 
URL", content=len(docs_content)) + docs_content = await fetch_content_from_url( + "https://docs.restack.io/llms-full.txt", + ) + log.info( + "Fetched content from URL", + content=len(docs_content), + ) return docs_content diff --git a/agent_video/pipecat/agent/src/functions/daily_create_room.py b/agent_video/pipecat/agent/src/functions/daily_create_room.py new file mode 100644 index 00000000..687aaa0c --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/daily_create_room.py @@ -0,0 +1,81 @@ +import os + +import aiohttp +from dotenv import load_dotenv +from pipecat.transports.services.helpers.daily_rest import ( + DailyRESTHelper, + DailyRoomParams, + DailyRoomProperties, +) +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +# Load environment variables from .env file +load_dotenv() + + +class DailyRoomOutput(BaseModel): + room_url: str + token: str + + +class DailyRoomInput(BaseModel): + room_name: str + + +@function.defn(name="daily_create_room") +async def daily_create_room( + function_input: DailyRoomInput, +) -> DailyRoomOutput: + try: + api_key = os.getenv("DAILYCO_API_KEY") + if not api_key: + raise ValueError( + "DAILYCO_API_KEY not set in environment.", + ) + + async with aiohttp.ClientSession() as daily_session: + daily_rest_helper = DailyRESTHelper( + daily_api_key=api_key, + daily_api_url="https://api.daily.co/v1", + aiohttp_session=daily_session, + ) + + room = await daily_rest_helper.create_room( + params=DailyRoomParams( + name=function_input.room_name, + properties=DailyRoomProperties( + start_video_off=True, + start_audio_off=False, + max_participants=2, + enable_prejoin_ui=False, + ), + ), + ) + + # Create a meeting token for the given room with an expiration 1 hour in + # the future. 
+ expiry_time: float = 60 * 60 + + token = await daily_rest_helper.get_token( + room.url, + expiry_time, + ) + + if not token: + raise NonRetryableError( + "No session token found in the response.", + ) + + log.info("daily_room token", token=token) + return DailyRoomOutput(room_url=room.url, token=token) + + except Exception as e: + log.error("Error creating daily room", error=e) + raise NonRetryableError( + f"Error creating daily room: {e}", + ) from e diff --git a/agent_video/pipecat/agent/src/functions/daily_send_data.py b/agent_video/pipecat/agent/src/functions/daily_send_data.py new file mode 100644 index 00000000..fbd22223 --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/daily_send_data.py @@ -0,0 +1,67 @@ +import os + +import aiohttp +from dotenv import load_dotenv +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +load_dotenv() + + +async def send_data_to_room( + room_name: str, data: dict, recipient: str | None = "*" +) -> bool: + """Send a message to a Daily room.""" + api_key = os.getenv("DAILYCO_API_KEY") + if not api_key: + raise ValueError( + "DAILYCO_API_KEY not set in environment." 
+ ) + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + url = f"https://api.daily.co/v1/rooms/{room_name}/send-app-message" + recipient = recipient or "*" + data = {"data": data, "recipient": recipient} + + async with aiohttp.ClientSession() as session: + async with session.post( + url, headers=headers, json=data + ) as response: + if response.status != 200: + text = await response.text() + raise Exception( + f"Failed to send message (status: {response.status}): {text}" + ) + + return True + + +class DailySendDataInput(BaseModel): + room_url: str + data: dict + recipient: str | None = "*" + + +@function.defn(name="daily_send_data") +async def daily_send_data( + function_input: DailySendDataInput, +) -> bool: + try: + return await send_data_to_room( + room_name=function_input.room_url.split("/")[-1], + data=function_input.data, + recipient=function_input.recipient, + ) + except Exception as e: + log.error("Error sending message to daily room", error=e) + raise NonRetryableError( + f"Error sending message to daily room: {e}", + ) from e diff --git a/agent_video/pipecat/agent/src/functions/llm_logic.py b/agent_video/pipecat/agent/src/functions/llm_logic.py new file mode 100644 index 00000000..eab7b1c3 --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/llm_logic.py @@ -0,0 +1,63 @@ +import os +from typing import Literal + +from openai import AsyncOpenAI +from pydantic import BaseModel +from restack_ai.function import NonRetryableError, function + +from src.functions.llm_talk import Message + + +class LlmLogicResponse(BaseModel): + """Structured AI decision output used to interrupt conversations.""" + + action: Literal["interrupt", "update_context", "end_call"] + reason: str + updated_context: str + + +class LlmLogicInput(BaseModel): + messages: list[Message] + documentation: str + reasoning_prompt: str | None = None + + +@function.defn() +async def llm_logic( + function_input: LlmLogicInput, +) -> 
LlmLogicResponse: + try: + client = AsyncOpenAI( + api_key=os.environ.get("OPENAI_API_KEY") + ) + + if function_input.reasoning_prompt: + system_prompt = ( + function_input.reasoning_prompt + + f"Restack Documentation: {function_input.documentation}" + ) + else: + system_prompt = ( + "Analyze the developer's questions and determine if an interruption is needed. " + "For example, to ask a follow up question and keep the conversation going. " + "Use the Restack documentation for accurate answers. " + "Track what the developer has learned and update their belief state." + f"Restack Documentation: {function_input.documentation}" + ) + + response = await client.beta.chat.completions.parse( + model="gpt-4o", + messages=[ + { + "role": "system", + "content": system_prompt, + }, + *function_input.messages, + ], + response_format=LlmLogicResponse, + ) + + return response.choices[0].message.parsed + + except Exception as e: + raise NonRetryableError(f"llm_slow failed: {e}") from e diff --git a/agent_video/pipecat/agent/src/functions/llm_talk.py b/agent_video/pipecat/agent/src/functions/llm_talk.py new file mode 100644 index 00000000..629a5d8d --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/llm_talk.py @@ -0,0 +1,93 @@ +import os +from typing import Literal + +from openai import OpenAI +from pydantic import BaseModel, Field +from restack_ai.function import ( + NonRetryableError, + function, + stream_to_websocket, +) + +from src.client import api_address + +class Message(BaseModel): + role: str + content: str + +ModelType = Literal["gpt-4o-mini", "ft:gpt-4o-mini-2024-07-18:restack::BJymdMm8", "openpipe:twenty-lions-fall"] + +class LlmTalkInput(BaseModel): + messages: list[Message] = Field(default_factory=list) + context: str | None = None # Updated context from Slow AI + mode: Literal["default", "interrupt"] + stream: bool = True + model: ModelType = "gpt-4o-mini" + interactive_prompt: str | None = None + + +@function.defn() +async def llm_talk(function_input: 
LlmTalkInput) -> str: + """Fast AI generates responses while checking for memory updates.""" + try: + # if model starts with openpipe use base_url + if function_input.model.startswith("openpipe:"): + client = OpenAI( + api_key=os.environ.get("OPENPIPE_API_KEY"), + base_url=f"https://app.openpipe.ai/api/v1/", + ) + else: + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + interactive_prompt = function_input.interactive_prompt + + if interactive_prompt: + system_prompt = ( + interactive_prompt + + "Current context: " + + function_input.context + ) + + else: + common_prompt = ( + "Your are an AI assistant helping developers build with restack: the backend framework for accurate & reliable AI agents." + "Your interface with users will be voice. Be friendly, helpful and avoid usage of unpronouncable punctuation." + "Always try to bring back the conversation to restack if the user is talking about something else. " + "Current context: " + function_input.context + ) + + if function_input.mode == "default": + system_prompt = ( + common_prompt + + "If you don't know an answer, **do not make something up**. Instead, be friendly and acknowledge that " + "you will check for the correct response and let the user know. Keep your answer short in max 20 words" + ) + else: + system_prompt = ( + common_prompt + + "You are providing a short and precise update based on new information. " + "Do not re-explain everything, just deliver the most important update. Keep your answer short in max 20 words unless the user asks for more information." 
+ ) + + function_input.messages.insert( + 0, Message(role="system", content=system_prompt) + ) + + messages_dicts = [ + msg.model_dump() for msg in function_input.messages + ] + + response = client.chat.completions.create( + model=function_input.model, + messages=messages_dicts, + stream=function_input.stream, + ) + + if function_input.stream: + return await stream_to_websocket( + api_address=api_address, data=response + ) + return response.choices[0].message.content + + except Exception as e: + raise NonRetryableError(f"llm_talk failed: {e}") from e diff --git a/agent_video/pipecat/agent/src/functions/send_agent_event.py b/agent_video/pipecat/agent/src/functions/send_agent_event.py new file mode 100644 index 00000000..33edcc34 --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/send_agent_event.py @@ -0,0 +1,31 @@ +from typing import Any + +from pydantic import BaseModel +from restack_ai.function import NonRetryableError, function + +from src.client import client + + +class SendAgentEventInput(BaseModel): + event_name: str + agent_id: str + run_id: str | None = None + event_input: dict[str, Any] | None = None + + +@function.defn() +async def send_agent_event( + function_input: SendAgentEventInput, +) -> str: + try: + return await client.send_agent_event( + event_name=function_input.event_name, + agent_id=function_input.agent_id, + run_id=function_input.run_id, + event_input=function_input.event_input, + ) + + except Exception as e: + raise NonRetryableError( + f"send_agent_event failed: {e}" + ) from e diff --git a/agent_video/pipecat/agent/src/functions/tavus_create_room.py b/agent_video/pipecat/agent/src/functions/tavus_create_room.py new file mode 100644 index 00000000..e4f0d41f --- /dev/null +++ b/agent_video/pipecat/agent/src/functions/tavus_create_room.py @@ -0,0 +1,56 @@ +import os + +import aiohttp +from dotenv import load_dotenv +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +# Load 
environment variables from .env file +load_dotenv() + + +class TavusRoomOutput(BaseModel): + room_url: str + + +@function.defn(name="tavus_create_room") +async def tavus_create_room() -> TavusRoomOutput: + try: + api_key = os.getenv("TAVUS_API_KEY") + replica_id = os.getenv("TAVUS_REPLICA_ID") + if not api_key or not replica_id: + raise ValueError( + "TAVUS_API_KEY or TAVUS_REPLICA_ID not set in environment.", + ) + + async with aiohttp.ClientSession() as session: + url = "https://tavusapi.com/v2/conversations" + headers = { + "Content-Type": "application/json", + "x-api-key": api_key, + } + payload = { + "replica_id": replica_id, + "persona_id": "pipecat0", + } + + async with session.post( + url, headers=headers, json=payload + ) as r: + r.raise_for_status() + response_json = await r.json() + + log.info("Tavus room created", response=response_json) + return TavusRoomOutput( + room_url=response_json["conversation_url"], + ) + + except Exception as e: + log.error("Error creating Tavus room", error=e) + raise NonRetryableError( + f"Error creating Tavus room: {e}", + ) from e diff --git a/agent_video/pipecat/agent/src/services.py b/agent_video/pipecat/agent/src/services.py new file mode 100644 index 00000000..94cb7d58 --- /dev/null +++ b/agent_video/pipecat/agent/src/services.py @@ -0,0 +1,61 @@ +import asyncio +import logging +import webbrowser +from pathlib import Path + +from restack_ai.restack import ServiceOptions +from watchfiles import run_process + +from src.agents.agent import AgentVideo +from src.client import client +from src.functions.context_docs import context_docs +from src.functions.daily_create_room import daily_create_room +from src.functions.daily_send_data import daily_send_data +from src.functions.llm_logic import llm_logic +from src.functions.llm_talk import llm_talk +from src.functions.send_agent_event import send_agent_event +from src.functions.tavus_create_room import tavus_create_room +from src.workflows.logic import LogicWorkflow +from 
src.workflows.room import RoomWorkflow + + +async def main() -> None: + await client.start_service( + agents=[AgentVideo], + workflows=[RoomWorkflow, LogicWorkflow], + functions=[ + llm_logic, + llm_talk, + context_docs, + daily_create_room, + tavus_create_room, + daily_send_data, + send_agent_event, + ], + options=ServiceOptions( + endpoint_group="agent_video", # used to locally show both agent and pipeline endpoint in UI + ), + ) + + +def run_services() -> None: + try: + asyncio.run(main()) + except KeyboardInterrupt: + logging.info( + "Service interrupted by user. Exiting gracefully.", + ) + + +def watch_services() -> None: + watch_path = Path.cwd() + logging.info( + "Watching %s and its subdirectories for changes...", + watch_path, + ) + webbrowser.open("http://localhost:5233") + run_process(watch_path, recursive=True, target=run_services) + + +if __name__ == "__main__": + run_services() diff --git a/agent_video/pipecat/agent/src/workflows/logic.py b/agent_video/pipecat/agent/src/workflows/logic.py new file mode 100644 index 00000000..b6ec042b --- /dev/null +++ b/agent_video/pipecat/agent/src/workflows/logic.py @@ -0,0 +1,170 @@ +from datetime import timedelta +from typing import Literal + +from pydantic import BaseModel +from restack_ai.workflow import ( + NonRetryableError, + import_functions, + log, + workflow, + workflow_info, +) + +with import_functions(): + from src.functions.context_docs import context_docs + from src.functions.daily_send_data import ( + DailySendDataInput, + daily_send_data, + ) + from src.functions.llm_logic import ( + LlmLogicInput, + LlmLogicResponse, + llm_logic, + ) + from src.functions.llm_talk import ( + LlmTalkInput, + Message, + llm_talk, + LlmTalkInput, + ModelType + ) + from src.functions.send_agent_event import ( + SendAgentEventInput, + send_agent_event, + ) + + +class LogicWorkflowInput(BaseModel): + messages: list[Message] + context: str + room_url: str + interactive_prompt: str | None = None + reasoning_prompt: str | 
None = None + model: ModelType + + +class LogicWorkflowOutput(BaseModel): + result: str + room_url: str + reasoning_prompt: str | None = None + + +@workflow.defn() +class LogicWorkflow: + @workflow.run + async def run( + self, workflow_input: LogicWorkflowInput + ) -> str: + context = workflow_input.context + + parent_agent_id = workflow_info().parent.workflow_id + parent_agent_run_id = workflow_info().parent.run_id + + log.info("LogicWorkflow started") + try: + documentation = await workflow.step( + function=context_docs + ) + + slow_response: LlmLogicResponse = await workflow.step( + function=llm_logic, + function_input=LlmLogicInput( + messages=[ + msg.model_dump() + for msg in workflow_input.messages + ], + documentation=documentation, + reasoning_prompt=workflow_input.reasoning_prompt, + ), + start_to_close_timeout=timedelta(seconds=60), + ) + + log.info(f"Slow response: {slow_response}") + + context = slow_response.updated_context + + await workflow.step( + function=send_agent_event, + function_input=SendAgentEventInput( + event_name="context", + agent_id=parent_agent_id, + run_id=parent_agent_run_id, + event_input={"context": str(context)}, + ), + ) + + if slow_response.action == "interrupt": + interrupt_response = await workflow.step( + function=llm_talk, + function_input=LlmTalkInput( + messages=[ + Message( + role="system", + content=slow_response.reason, + ) + ], + context=str(context), + mode="interrupt", + stream=False, + model=workflow_input.model, + interactive_prompt=workflow_input.interactive_prompt, + ), + start_to_close_timeout=timedelta(seconds=3), + ) + + await workflow.step( + function=daily_send_data, + function_input=DailySendDataInput( + room_url=workflow_input.room_url, + data={"text": interrupt_response}, + ), + ) + + if slow_response.action == "end_call": + goodbye_message = await workflow.step( + function=llm_talk, + function_input=LlmTalkInput( + messages=[ + Message( + role="system", + content="Say goodbye to the user by providing a 
unique and short message based on context.", + ) + ], + context=str(context), + mode="interrupt", + model=workflow_input.model, + stream=False, + ), + start_to_close_timeout=timedelta(seconds=3), + ) + + log.info(f"Goodbye message: {goodbye_message}") + + await workflow.step( + function=daily_send_data, + function_input=DailySendDataInput( + room_url=workflow_input.room_url, + data={"text": goodbye_message}, + ), + ) + + await workflow.sleep(1) + + await workflow.step( + function=send_agent_event, + function_input=SendAgentEventInput( + event_name="end", + agent_id=parent_agent_id, + run_id=parent_agent_run_id, + event_input={"end": True}, + ), + ) + + except Exception as e: + error_message = f"Error during welcome: {e}" + raise NonRetryableError(error_message) from e + else: + log.info( + "LogicWorkflow completed", context=str(context) + ) + return str(context) diff --git a/agent_video/pipecat/agent/src/workflows/room.py b/agent_video/pipecat/agent/src/workflows/room.py new file mode 100644 index 00000000..a1477c31 --- /dev/null +++ b/agent_video/pipecat/agent/src/workflows/room.py @@ -0,0 +1,131 @@ +from datetime import timedelta +from typing import Literal + +from pydantic import BaseModel +from restack_ai.workflow import ( + NonRetryableError, + ParentClosePolicy, + import_functions, + log, + workflow, + workflow_info, +) +from src.agents.agent import AgentInput, AgentVideo + +with import_functions(): + from src.functions.daily_create_room import ( + DailyRoomInput, + daily_create_room, + ) + from src.functions.tavus_create_room import tavus_create_room + + +class RoomWorkflowOutput(BaseModel): + agent_name: str + agent_id: str + agent_run_id: str + room_url: str + token: str | None = None + + +class RoomWorkflowInput(BaseModel): + video_service: Literal["tavus", "heygen", "audio"] + model: Literal["gpt-4o-mini", "openpipe:twenty-lions-fall", "ft:gpt-4o-mini-2024-07-18:restack::BJymdMm8"] = "gpt-4o-mini" + interactive_prompt: str | None = None + 
reasoning_prompt: str | None = None + + +class PipelineWorkflowInput(BaseModel): + video_service: Literal["tavus", "heygen", "audio"] + agent_name: str + agent_id: str + agent_run_id: str + daily_room_url: str | None = None + daily_room_token: str | None = None + + +@workflow.defn() +class RoomWorkflow: + @workflow.run + async def run( + self, workflow_input: RoomWorkflowInput + ) -> RoomWorkflowOutput: + try: + daily_room = None + room_url = None + token = None + + agent_id = f"{workflow_info().workflow_id}-agent" + pipeline_id = ( + f"{workflow_info().workflow_id}-pipeline" + ) + + if workflow_input.video_service == "heygen": + daily_room = await workflow.step( + function=daily_create_room, + function_input=DailyRoomInput( + room_name=workflow_info().run_id, + ), + ) + room_url = daily_room.room_url + token = daily_room.token + + if workflow_input.video_service == "audio": + daily_room = await workflow.step( + function=daily_create_room, + function_input=DailyRoomInput( + room_name=workflow_info().run_id, + ), + ) + room_url = daily_room.room_url + token = daily_room.token + + if workflow_input.video_service == "tavus": + tavus_room = await workflow.step( + function=tavus_create_room, + ) + room_url = tavus_room.room_url + + agent = await workflow.child_start( + agent=AgentVideo, + agent_id=agent_id, + agent_input=AgentInput( + room_url=room_url, + model=workflow_input.model, + interactive_prompt=workflow_input.interactive_prompt, + reasoning_prompt=workflow_input.reasoning_prompt, + ), + start_to_close_timeout=timedelta(minutes=20), + parent_close_policy=ParentClosePolicy.ABANDON, + ) + + await workflow.child_start( + task_queue="pipeline", + workflow="PipelineWorkflow", + workflow_id=pipeline_id, + workflow_input=PipelineWorkflowInput( + video_service=workflow_input.video_service, + agent_name=AgentVideo.__name__, + agent_id=agent.id, + agent_run_id=agent.run_id, + daily_room_url=room_url if room_url else None, + daily_room_token=token if token else None, + ), 
+ start_to_close_timeout=timedelta(minutes=20), + parent_close_policy=ParentClosePolicy.ABANDON, + ) + + except Exception as e: + error_message = f"Error during PipelineWorkflow: {e}" + raise NonRetryableError(error_message) from e + + else: + log.info("RoomWorkflow completed", room_url=room_url) + + return RoomWorkflowOutput( + agent_name=AgentVideo.__name__, + agent_id=agent.id, + agent_run_id=agent.run_id, + room_url=room_url, + token=token if token else None, + ) diff --git a/agent_video/agent_messages.png b/agent_video/pipecat/agent_messages.png similarity index 100% rename from agent_video/agent_messages.png rename to agent_video/pipecat/agent_messages.png diff --git a/agent_video/pipecat/pipeline/.env.example b/agent_video/pipecat/pipeline/.env.example new file mode 100644 index 00000000..9e1d62b6 --- /dev/null +++ b/agent_video/pipecat/pipeline/.env.example @@ -0,0 +1,7 @@ +DEEPGRAM_API_KEY= +CARTESIA_API_KEY= +CARTESIA_VOICE_ID= +ELEVENLABS_API_KEY= + +TAVUS_API_KEY= +TAVUS_REPLICA_ID= \ No newline at end of file diff --git a/agent_video/pipecat/pipeline/.python-version b/agent_video/pipecat/pipeline/.python-version new file mode 100644 index 00000000..e4fba218 --- /dev/null +++ b/agent_video/pipecat/pipeline/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/agent_video/pipecat/pipeline/Dockerfile b/agent_video/pipecat/pipeline/Dockerfile new file mode 100644 index 00000000..679cdafc --- /dev/null +++ b/agent_video/pipecat/pipeline/Dockerfile @@ -0,0 +1,16 @@ +FROM ghcr.io/astral-sh/uv:python3.10-bookworm-slim + +WORKDIR /app + +COPY pyproject.toml ./ + +COPY . . 
+ +# Install dependencies +RUN uv sync --no-dev + +# Expose port 80 +EXPOSE 80 + +CMD ["uv", "run", "services"] + diff --git a/agent_video/pipecat/pipeline/README.md b/agent_video/pipecat/pipeline/README.md new file mode 100644 index 00000000..6fd018e9 --- /dev/null +++ b/agent_video/pipecat/pipeline/README.md @@ -0,0 +1,2 @@ + +See parent README.md at /agent_video/pipecat/README.md for instructions on how to run the agent. \ No newline at end of file diff --git a/agent_video/pyproject.toml b/agent_video/pipecat/pipeline/pyproject.toml similarity index 86% rename from agent_video/pyproject.toml rename to agent_video/pipecat/pipeline/pyproject.toml index de0dc69f..8ddbec11 100644 --- a/agent_video/pyproject.toml +++ b/agent_video/pipecat/pipeline/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "agent_video" +name = "agent_video_pipecat_pipeline" version = "0.1.0" description = "Add your description here" readme = "README.md" @@ -10,7 +10,9 @@ dependencies = [ "python-dotenv>=1.0.1", "pydantic>=2.10.6", "watchfiles>=1.0.4", - "restack-ai>=0.0.81",] + "restack-ai>=0.0.87", + "livekit>=0.21.3", +] [project.scripts] dev = "src.services:watch_services" diff --git a/agent_video/src/workflows/__init__.py b/agent_video/pipecat/pipeline/src/__init__.py similarity index 100% rename from agent_video/src/workflows/__init__.py rename to agent_video/pipecat/pipeline/src/__init__.py diff --git a/agent_video/pipecat/pipeline/src/client.py b/agent_video/pipecat/pipeline/src/client.py new file mode 100644 index 00000000..2ff14fbb --- /dev/null +++ b/agent_video/pipecat/pipeline/src/client.py @@ -0,0 +1,22 @@ +import os + +from dotenv import load_dotenv +from restack_ai import Restack +from restack_ai.restack import CloudConnectionOptions + +# Load environment variables from a .env file +load_dotenv() + + +engine_id = os.getenv("RESTACK_ENGINE_ID") +address = os.getenv("RESTACK_ENGINE_ADDRESS") +api_key = os.getenv("RESTACK_ENGINE_API_KEY") +api_address = 
os.getenv("RESTACK_ENGINE_API_ADDRESS") + +connection_options = CloudConnectionOptions( + engine_id=engine_id, + address=address, + api_key=api_key, + api_address=api_address, +) +client = Restack(connection_options) diff --git a/agent_video/pipecat/pipeline/src/functions/__init__.py b/agent_video/pipecat/pipeline/src/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agent_video/pipecat/pipeline/src/functions/daily_delete_room.py b/agent_video/pipecat/pipeline/src/functions/daily_delete_room.py new file mode 100644 index 00000000..928bb295 --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/daily_delete_room.py @@ -0,0 +1,56 @@ +import os + +import aiohttp +from dotenv import load_dotenv +from pipecat.transports.services.helpers.daily_rest import ( + DailyRESTHelper, +) +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +# Load environment variables from .env file +load_dotenv() + + +class DailyDeleteRoomInput(BaseModel): + room_name: str + + +@function.defn(name="daily_delete_room") +async def daily_delete_room( + function_input: DailyDeleteRoomInput, +) -> bool: + try: + api_key = os.getenv("DAILYCO_API_KEY") + if not api_key: + raise ValueError( + "DAILYCO_API_KEY not set in environment.", + ) + + async with aiohttp.ClientSession() as daily_session: + daily_rest_helper = DailyRESTHelper( + daily_api_key=api_key, + daily_api_url="https://api.daily.co/v1", + aiohttp_session=daily_session, + ) + + deleted_room = ( + await daily_rest_helper.delete_room_by_name( + function_input.room_name + ) + ) + + log.info( + "daily_room deleted", deleted_room=deleted_room + ) + return deleted_room + + except Exception as e: + log.error("Error deleting daily room", error=e) + raise NonRetryableError( + f"Error deleting daily room: {e}", + ) from e diff --git a/agent_video/pipecat/pipeline/src/functions/deprecated_utils/aiohttp_session.py 
b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/aiohttp_session.py new file mode 100644 index 00000000..704640ed --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/aiohttp_session.py @@ -0,0 +1,43 @@ +import aiohttp +from aiohttp import ClientSession + +# Add a module-level dictionary for reusing ClientSession keyed by workflow_run_id +SESSIONS: dict[str, ClientSession] = {} + +from restack_ai.function import ( + NonRetryableError, + function, + function_info, + log, +) + + +@function.defn(name="create_aiohttp_session") +async def create_aiohttp_session() -> str: + try: + workflow_run_id = function_info().workflow_run_id + if workflow_run_id not in SESSIONS: + SESSIONS[workflow_run_id] = aiohttp.ClientSession() + return workflow_run_id + except Exception as e: + log.error("aiohttp_session error", error=e) + raise NonRetryableError( + f"aiohttp_session error: {e}", + ) from e + + +@function.defn(name="get_aiohttp_session") +async def get_aiohttp_session( + workflow_run_id: str, +) -> ClientSession: + """Retrieve the stored aiohttp ClientSession for the given workflow_run_id.""" + try: + return SESSIONS[workflow_run_id] + except KeyError: + log.error( + "get_aiohttp_session: No session found for workflow_run_id", + workflow_run_id=workflow_run_id, + ) + raise NonRetryableError( + f"No session found for workflow_run_id: {workflow_run_id}", + ) diff --git a/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_session.py b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_session.py new file mode 100644 index 00000000..440512de --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_session.py @@ -0,0 +1,92 @@ +import os + +from dotenv import load_dotenv +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + function_info, + log, +) +from src.functions.aiohttp_session import 
get_aiohttp_session + +# Load environment variables from .env file +load_dotenv() + + +class HeygenStreamingSessionOutput(BaseModel): + session_id: str + access_token: str + realtime_endpoint: str + url: str + + +@function.defn(name="heygen_streaming_session") +async def heygen_streaming_session() -> ( + HeygenStreamingSessionOutput +): + try: + api_key = os.getenv("HEYGEN_API_KEY") + if not api_key: + raise ValueError( + "HEYGEN_API_KEY not set in environment.", + ) + + session = await get_aiohttp_session( + function_info().workflow_run_id, + ) + + url = "https://api.heygen.com/v1/streaming.new" + payload = { + "avatarName": "Bryan_IT_Sitting_public", + "version": "v2", + "video_encoding": "H264", + "source": "sdk", + } + headers = { + "accept": "application/json", + "content-type": "application/json", + "x-api-key": api_key, + } + + async with session.post( + url, + json=payload, + headers=headers, + ) as response: + if response.status != 200: + raise NonRetryableError( + f"Error: Received status code {response.status} with details: {await response.text()}", + ) + data = (await response.json()).get("data", {}) + + log.info("Heygen streaming session data", data=data) + + session_id = data.get("session_id") + access_token = data.get("access_token") + realtime_endpoint = data.get("realtime_endpoint") + url_value = data.get("url") + + if ( + not session_id + or not access_token + or not realtime_endpoint + or not url_value + ): + log.error( + "Incomplete Heygen streaming session response", + data=data, + ) + raise NonRetryableError( + "Incomplete Heygen streaming session response: missing one of session_id, access_token, realtime_endpoint, or url.", + ) + return HeygenStreamingSessionOutput( + session_id=session_id, + access_token=access_token, + realtime_endpoint=realtime_endpoint, + url=url_value, + ) + except Exception as e: + raise NonRetryableError( + f"heygen_streaming_session error: {e}", + ) from e diff --git 
a/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_stop.py b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_stop.py new file mode 100644 index 00000000..22a71812 --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_streaming_stop.py @@ -0,0 +1,57 @@ +import os + +import requests +from dotenv import load_dotenv +from pydantic import BaseModel +from restack_ai.function import NonRetryableError, function, log + +# Load environment variables from .env file +load_dotenv() + + +class HeygenStreamingStopInput(BaseModel): + session_id: str + + +@function.defn(name="heygen_streaming_stop") +async def heygen_streaming_stop( + function_input: HeygenStreamingStopInput, +) -> bool: + try: + api_key = os.getenv("HEYGEN_API_KEY") + if not api_key: + raise ValueError( + "HEYGEN_API_KEY not set in environment.", + ) + + url = "https://api.heygen.com/v1/streaming.stop" + payload = { + "session_id": function_input.session_id, + } + headers = { + "accept": "application/json", + "content-type": "application/json", + "x-api-key": api_key, + } + + response = requests.post( + url, + json=payload, + headers=headers, + ) + if response.status_code != 200: + raise NonRetryableError( + f"Error: Received status code {response.status_code} with details: {response.text}", + ) + + message = response.json().get("message", {}) + + log.info("Heygen streaming session stop", message=message) + + if message == "success": + return True + return False + except Exception as e: + raise NonRetryableError( + f"heygen_streaming_session error: {e}", + ) from e diff --git a/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_token.py b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_token.py new file mode 100644 index 00000000..2e9f7bac --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/deprecated_utils/heygen_token.py @@ -0,0 +1,38 @@ +import os + +import requests +from 
dotenv import load_dotenv +from restack_ai.function import NonRetryableError, function, log + +# Load environment variables from .env file +load_dotenv() + + +@function.defn(name="heygen_token") +async def heygen_token(): + api_key = os.getenv("HEYGEN_API_KEY") + if not api_key: + raise ValueError("HEYGEN_API_KEY not set in environment.") + + url = "https://api.heygen.com/v1/streaming.create_token" + payload = {} + headers = { + "accept": "application/json", + "content-type": "application/json", + "x-api-key": api_key, + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + raise NonRetryableError( + f"Error: Received status code {response.status_code} with details: {response.text}", + ) + + token = response.json().get("data", {}).get("token") + log.info("Heygen token", token=token) + if not token: + raise NonRetryableError( + "No session token found in the response.", + ) + + return token diff --git a/agent_video/pipecat/pipeline/src/functions/heygen_client.py b/agent_video/pipecat/pipeline/src/functions/heygen_client.py new file mode 100644 index 00000000..f71a479e --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/heygen_client.py @@ -0,0 +1,191 @@ +from enum import Enum +from typing import Any, Literal + +import aiohttp +from loguru import logger +from pydantic import BaseModel, Field + + +class AvatarQuality(str, Enum): + low = "low" + medium = "medium" + high = "high" + + +class VoiceEmotion(str, Enum): + EXCITED = "excited" + SERIOUS = "serious" + FRIENDLY = "friendly" + SOOTHING = "soothing" + BROADCASTER = "broadcaster" + + +class ElevenLabsSettings(BaseModel): + stability: float | None = None + similarity_boost: float | None = None + style: int | None = None + use_speaker_boost: bool | None = None + + +class VoiceSettings(BaseModel): + voice_id: str | None = Field(None, alias="voiceId") + rate: float | None = None + emotion: VoiceEmotion | None = None + elevenlabs_settings: ElevenLabsSettings | 
None = Field(
+        None,
+        alias="elevenlabsSettings",
+    )
+
+
+class NewSessionRequest(BaseModel):
+    avatarName: str
+    quality: AvatarQuality | None = None
+    knowledgeId: str | None = None
+    knowledgeBase: str | None = None
+    voice: VoiceSettings | None = None
+    language: str | None = None
+    version: Literal["v2"] = "v2"
+    video_encoding: Literal["H264"] = "H264"
+    source: Literal["sdk"] = "sdk"
+    disableIdleTimeout: bool | None = None
+
+
+class SessionResponse(BaseModel):
+    session_id: str
+    access_token: str
+    realtime_endpoint: str
+    url: str
+
+
+class HeygenAPIError(Exception):
+    """Custom exception for HeyGen API errors."""
+
+    def __init__(
+        self,
+        message: str,
+        status: int,
+        response_text: str,
+    ) -> None:
+        super().__init__(message)
+        self.status = status
+        self.response_text = response_text
+
+
+class HeyGenClient:
+    """HeyGen Streaming API client."""
+
+    BASE_URL = "https://api.heygen.com/v1"
+
+    def __init__(
+        self,
+        api_key: str,
+        session: aiohttp.ClientSession | None = None,
+    ) -> None:
+        self.api_key = api_key
+        self.session = session or aiohttp.ClientSession()
+
+    async def request(
+        self,
+        path: str,
+        params: dict[str, Any],
+    ) -> Any:
+        """Make a POST request to the HeyGen API.
+
+        Args:
+            path (str): API endpoint path.
+            params (Dict[str, Any]): JSON-serializable parameters.
+
+        Returns:
+            Any: Parsed JSON response data.
+
+        Raises:
+            HeygenAPIError: If the API response is not successful.
+ + """ + url = f"{self.BASE_URL}{path}" + headers = { + "x-api-key": self.api_key, + "Content-Type": "application/json", + } + + async with self.session.post( + url, + json=params, + headers=headers, + ) as response: + if not response.ok: + response_text = await response.text() + logger.error( + "heygen api error", + response_text=response_text, + ) + raise HeygenAPIError( + f"API request failed with status {response.status}", + response.status, + response_text, + ) + json_data = await response.json() + return json_data.get("data") + + async def new_session( + self, + request_data: NewSessionRequest, + ) -> SessionResponse: + params = { + "avatar_name": request_data.avatarName, + "quality": request_data.quality, + "knowledge_base_id": request_data.knowledgeId, + "knowledge_base": request_data.knowledgeBase, + "voice": { + "voice_id": request_data.voice.voiceId + if request_data.voice + else None, + "rate": request_data.voice.rate + if request_data.voice + else None, + "emotion": request_data.voice.emotion + if request_data.voice + else None, + "elevenlabs_settings": ( + request_data.voice.elevenlabsSettings + if request_data.voice + else None + ), + }, + "language": request_data.language, + "version": "v2", + "video_encoding": "H264", + "source": "sdk", + "disable_idle_timeout": request_data.disableIdleTimeout, + } + session_info = await self.request( + "/streaming.new", + params, + ) + logger.info( + "heygen session info", + session_info=session_info, + ) + + return SessionResponse.model_validate(session_info) + + async def start_session(self, session_id: str) -> Any: + """Start the streaming session. + + Returns: + Any: Response data from the start session API call. + + """ + if not session_id: + raise ValueError( + "Session ID is not set. 
Call new_session first.", + ) + + params = { + "session_id": session_id, + } + return await self.request("/streaming.start", params) + + async def close(self) -> None: + """Close the aiohttp session.""" + await self.session.close() diff --git a/agent_video/pipecat/pipeline/src/functions/heygen_video_service.py b/agent_video/pipecat/pipeline/src/functions/heygen_video_service.py new file mode 100644 index 00000000..9a70c879 --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/heygen_video_service.py @@ -0,0 +1,590 @@ +import asyncio +import base64 +import json +import os +import uuid + +import aiohttp +import websockets +from livekit import rtc +from livekit.rtc._proto.video_frame_pb2 import VideoBufferType +from loguru import logger +from pipecat.audio.utils import create_default_resampler +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + OutputImageRawFrame, + StartFrame, + StartInterruptionFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import AIService + + +class HeyGenVideoService(AIService): + """Class to send agent audio to HeyGen using the streaming audio input api.""" + + def __init__( + self, + *, + session_id: str, + session_token: str, + realtime_endpoint: str, + session: aiohttp.ClientSession, + livekit_room_url: str, + api_base_url: str = "https://api.heygen.com", + **kwargs: dict, + ) -> None: + super().__init__(**kwargs) + self._session_id = session_id + self._session_token = session_token + self._session = session + self._api_base_url = api_base_url + self._websocket = None + self._buffered_audio_duration_ms = 0 + self._event_id = None + self._realtime_endpoint = realtime_endpoint + self._livekit_room_url = livekit_room_url + self._livekit_room = None + self._video_task = None + self._audio_task = None + self._video_event = asyncio.Event() + self._video_event.set() + + # Constants + 
SAMPLE_RATE = 24000 + BUFFER_DURATION_THRESHOLD_MS = 2000 + BUFFER_COMMIT_THRESHOLD_MS = 4000 + + # AI Service class methods + async def start(self, frame: StartFrame) -> None: + logger.info("HeyGenVideoService starting") + await super().start(frame) + await self._ws_connect() + await self._livekit_connect() + + async def stop(self, frame: EndFrame) -> None: + logger.info("HeyGenVideoService stopping") + await super().stop(frame) + await self._stop() + + async def cancel(self, frame: CancelFrame) -> None: + logger.info("HeyGenVideoService canceling") + await super().cancel(frame) + await self._ws_disconnect() + await self._livekit_disconnect() + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + + async def interrupt( + self, frame: StartInterruptionFrame + ) -> None: + logger.info("HeyGenVideoService interrupting") + await super().interrupt(frame) + await self._interrupt() + + # websocket connection methods + async def _ws_connect(self) -> None: + """Connect to HeyGen websocket endpoint.""" + try: + logger.info("HeyGenVideoService ws connecting") + if self._websocket: + # assume connected + return + self._websocket = await websockets.connect( + uri=self._realtime_endpoint, + ) + self._receive_task = ( + self.get_event_loop().create_task( + self._ws_receive_task_handler(), + ) + ) + except websockets.exceptions.WebSocketException as e: + logger.error(f"{self} initialization error: {e}") + self._websocket = None + + async def _ws_disconnect(self) -> None: + """Disconnect from HeyGen websocket endpoint.""" + try: + if self._websocket: + await self._websocket.close() + except websockets.exceptions.WebSocketException as e: + logger.error(f"{self} disconnect error: {e}") + finally: + self._websocket = None + + async def _ws_receive_task_handler(self) -> None: + """Handle incoming messages from HeyGen websocket.""" + try: + while True: + message = await self._websocket.recv() + try: + parsed_message = json.loads(message) + await 
self._handle_ws_server_event( + parsed_message, + ) + except json.JSONDecodeError as e: + logger.error( + f"Failed to parse websocket message as JSON: {e}", + ) + continue + if message: + logger.info( + f"HeyGenVideoService ws received message: {message}", + ) + + except websockets.exceptions.WebSocketException as e: + logger.error( + f"Error receiving message from websocket: {e}", + ) + + async def _handle_ws_server_event(self, event: dict) -> None: + """Handle an event from HeyGen websocket.""" + event_type = event.get("type") + if event_type == "agent.state": + logger.info( + f"HeyGenVideoService ws received agent state: {event}", + ) + else: + logger.error( + f"HeyGenVideoService ws received unknown event: {event_type}", + ) + + async def _ws_send(self, message: dict) -> None: + """Send a message to HeyGen websocket.""" + try: + # logger.debug( + # f"HeyGenVideoService ws sending message: {message.get('type')}", + # ) + if self._websocket: + await self._websocket.send(json.dumps(message)) + else: + logger.error(f"{self} websocket not connected") + except websockets.exceptions.WebSocketException as e: + logger.error( + f"Error sending message to websocket: {e}", + ) + await self.push_error( + ErrorFrame( + error=f"Error sending client event: {e}", + fatal=True, + ), + ) + + async def _stop_session(self) -> None: + """Stop the current session.""" + try: + await self._ws_disconnect() + except websockets.exceptions.WebSocketException as e: + logger.error(f"{self} stop ws error: {e}") + url = f"{self._api_base_url}/v1/streaming.stop" + headers = { + "Content-Type": "application/json", + "accept": "application/json", + "x-api-key": os.getenv("HEYGEN_API_KEY"), + } + body = {"session_id": self._session_id} + async with self._session.post( + url, + headers=headers, + json=body, + ) as r: + r.raise_for_status() + + async def _interrupt(self) -> None: + """Interrupt the current session.""" + url = f"{self._api_base_url}/v1/streaming.interrupt" + headers = { + 
"Content-Type": "application/json",
+            "accept": "application/json",
+            "x-api-key": os.getenv("HEYGEN_API_KEY"),
+        }
+        body = {"session_id": self._session_id}
+        async with self._session.post(
+            url,
+            headers=headers,
+            json=body,
+        ) as r:
+            r.raise_for_status()
+
+    # audio buffer methods
+    async def _send_audio(
+        self,
+        audio: bytes,
+        sample_rate: int,
+        event_id: str,
+        finish: bool = False,
+    ) -> None:
+        try:
+            if sample_rate != self.SAMPLE_RATE:
+                resampler = create_default_resampler()
+                audio = await resampler.resample(
+                    audio,
+                    sample_rate,
+                    self.SAMPLE_RATE,
+                )
+            # If sample_rate is already 24000 (SAMPLE_RATE), no resampling is needed
+            self._buffered_audio_duration_ms += (
+                self._calculate_audio_duration_ms(
+                    audio,
+                    self.SAMPLE_RATE,
+                )
+            )
+            await self._agent_audio_buffer_append(audio)
+
+            if (
+                finish
+                and self._buffered_audio_duration_ms
+                < self.BUFFER_DURATION_THRESHOLD_MS
+            ):
+                await self._agent_audio_buffer_clear()
+                self._buffered_audio_duration_ms = 0
+
+            if (
+                finish
+                or self._buffered_audio_duration_ms
+                > self.BUFFER_COMMIT_THRESHOLD_MS
+            ):
+                logger.info(
+                    f"Audio buffer duration from buffer: {self._buffered_audio_duration_ms:.2f}ms",
+                )
+                await self._agent_audio_buffer_commit()
+                self._buffered_audio_duration_ms = 0
+        except Exception as e:
+            logger.error(
+                f"Error sending audio: {e}",
+                exc_info=True,
+            )
+
+    def _calculate_audio_duration_ms(
+        self,
+        audio: bytes,
+        sample_rate: int,
+    ) -> float:
+        # Each sample is 2 bytes (16-bit audio)
+        num_samples = len(audio) / 2
+        return (num_samples / sample_rate) * 1000
+
+    async def _agent_audio_buffer_append(
+        self,
+        audio: bytes,
+    ) -> None:
+        audio_base64 = base64.b64encode(audio).decode("utf-8")
+        await self._ws_send(
+            {
+                "type": "agent.audio_buffer_append",
+                "audio": audio_base64,
+                "event_id": str(uuid.uuid4()),
+            },
+        )
+
+    async def _agent_audio_buffer_clear(self) -> None:
+        await self._ws_send(
+            {
+                "type": "agent.audio_buffer_clear",
+                "event_id": str(uuid.uuid4()),
+            },
) + + async def _agent_audio_buffer_commit(self) -> None: + audio_base64 = base64.b64encode(b"\x00").decode("utf-8") + await self._ws_send( + { + "type": "agent.audio_buffer_commit", + "audio": audio_base64, + "event_id": str(uuid.uuid4()), + }, + ) + + # LiveKit connection methods + async def _process_audio_frames( + self, + stream: rtc.AudioStream, + ) -> None: + """Process audio frames from LiveKit stream.""" + frame_count = 0 + try: + logger.info("Starting audio frame processing...") + async for frame_event in stream: + frame_count += 1 + try: + audio_frame = frame_event.frame + # Convert audio to raw bytes + audio_data = bytes(audio_frame.data) + + # Create TTSAudioRawFrame + audio_frame = TTSAudioRawFrame( + audio=audio_data, + sample_rate=audio_frame.sample_rate, + num_channels=1, # HeyGen uses mono audio + ) + # Mark this frame as coming from LiveKit to avoid reprocessing + + await self.push_frame(audio_frame) + + except Exception as frame_error: + logger.error( + f"Error processing audio frame #{frame_count}: {frame_error!s}", + exc_info=True, + ) + except Exception as e: + logger.error( + f"Audio frame processing error after {frame_count} frames: {e!s}", + exc_info=True, + ) + finally: + logger.info( + f"Audio frame processing ended. 
Total frames processed: {frame_count}", + ) + + async def _process_video_frames( + self, + stream: rtc.VideoStream, + ) -> None: + """Process video frames from LiveKit stream.""" + frame_count = 0 + try: + logger.info("Starting video frame processing...") + async for frame_event in stream: + # Wait for video processing to be enabled + await self._video_event.wait() + + frame_count += 1 + try: + video_frame = frame_event.frame + + # Convert to RGB24 if not already + if video_frame.type != VideoBufferType.RGB24: + video_frame = video_frame.convert( + VideoBufferType.RGB24, + ) + + # Create frame with original dimensions + image_frame = OutputImageRawFrame( + image=bytes(video_frame.data), + size=( + video_frame.width, + video_frame.height, + ), + format="RGB", + ) + image_frame.pts = ( + frame_event.timestamp_us // 1000 + ) # Convert to milliseconds + + await self.push_frame(image_frame) + + except Exception as frame_error: + logger.error( + f"Error processing individual frame #{frame_count}: {frame_error!s}", + exc_info=True, + ) + except Exception as e: + logger.error( + f"Video frame processing error after {frame_count} frames: {e!s}", + exc_info=True, + ) + finally: + logger.info( + f"Video frame processing ended. 
Total frames processed: {frame_count}", + ) + + async def _livekit_connect(self) -> None: + """Connect to LiveKit room.""" + try: + logger.info( + f"HeyGenVideoService livekit connecting to room URL: {self._livekit_room_url}", + ) + self._livekit_room = rtc.Room() + + @self._livekit_room.on("participant_connected") + def on_participant_connected( + participant: rtc.RemoteParticipant, + ) -> None: + logger.info( + f"Participant connected - SID: {participant.sid}, Identity: {participant.identity}", + ) + for ( + track_pub + ) in participant.track_publications.values(): + logger.info( + f"Available track - SID: {track_pub.sid}, Kind: {track_pub.kind}, Name: {track_pub.name}", + ) + + @self._livekit_room.on("track_subscribed") + def on_track_subscribed( + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + ) -> None: + logger.info( + f"Track subscribed - SID: {publication.sid}, Kind: {track.kind}, Source: {publication.source}", + ) + if track.kind == rtc.TrackKind.KIND_VIDEO: + logger.info( + f"Creating video stream processor for track: {publication.sid}", + ) + video_stream = rtc.VideoStream(track) + self._video_task = self.create_task( + self._process_video_frames(video_stream), + ) + elif track.kind == rtc.TrackKind.KIND_AUDIO: + logger.info( + f"Creating audio stream processor for track: {publication.sid}", + ) + audio_stream = rtc.AudioStream(track) + self._audio_task = self.create_task( + self._process_audio_frames(audio_stream), + ) + + @self._livekit_room.on("track_unsubscribed") + def on_track_unsubscribed( + track: rtc.Track, + publication: rtc.RemoteTrackPublication, + ) -> None: + logger.info( + f"Track unsubscribed - SID: {publication.sid}, Kind: {track.kind}", + ) + + @self._livekit_room.on("participant_disconnected") + def on_participant_disconnected( + participant: rtc.RemoteParticipant, + ) -> None: + logger.info( + f"Participant disconnected - SID: {participant.sid}, Identity: {participant.identity}", + ) + + logger.info( + "Attempting to 
connect to LiveKit room...", + ) + await self._livekit_room.connect( + self._livekit_room_url, + self._session_token, + ) + logger.info( + f"Successfully connected to LiveKit room: {self._livekit_room.name}", + ) + + # Log initial room state + logger.info(f"Room name: {self._livekit_room.name}") + logger.info( + f"Local participant SID: {self._livekit_room.local_participant.sid}", + ) + logger.info( + f"Number of remote participants: {len(self._livekit_room.remote_participants)}", + ) + + # Log existing participants and their tracks + for ( + participant + ) in self._livekit_room.remote_participants.values(): + logger.info( + f"Existing participant - SID: {participant.sid}, Identity: {participant.identity}", + ) + for ( + track_pub + ) in participant.track_publications.values(): + logger.info( + f"Existing track - SID: {track_pub.sid}, Kind: {track_pub.kind}, Name: {track_pub.name}", + ) + + except Exception as e: + logger.error( + f"LiveKit initialization error: {e!s}", + exc_info=True, + ) + self._livekit_room = None + + async def _livekit_disconnect(self) -> None: + """Disconnect from LiveKit room.""" + try: + logger.info("Starting LiveKit disconnect...") + if self._video_task: + logger.info("Canceling video processing task") + await self.cancel_task(self._video_task) + self._video_task = None + logger.info( + "Video processing task cancelled successfully", + ) + + if self._audio_task: + logger.info("Canceling audio processing task") + await self.cancel_task(self._audio_task) + self._audio_task = None + logger.info( + "Audio processing task cancelled successfully", + ) + + if self._livekit_room: + logger.info("Disconnecting from LiveKit room") + await self._livekit_room.disconnect() + self._livekit_room = None + logger.info( + "Successfully disconnected from LiveKit room", + ) + except Exception as e: + logger.error( + f"LiveKit disconnect error: {e!s}", + exc_info=True, + ) + + async def _stop(self) -> None: + """Stop all processing and disconnect.""" + if 
self._video_task: + await self.cancel_task(self._video_task) + self._video_task = None + if self._audio_task: + await self.cancel_task(self._audio_task) + self._audio_task = None + + await self._ws_disconnect() + await self._livekit_disconnect() + await self._stop_session() + + async def process_frame( + self, + frame: Frame, + direction: FrameDirection, + ) -> None: + await super().process_frame(frame, direction) + try: + if isinstance(frame, TTSStartedFrame): + logger.info("HeyGenVideoService TTS started") + await self.start_processing_metrics() + await self.start_ttfb_metrics() + self._event_id = str(uuid.uuid4()) + await self._agent_audio_buffer_clear() + elif isinstance(frame, TTSAudioRawFrame): + await self._send_audio( + frame.audio, + frame.sample_rate, + self._event_id, + finish=False, + ) + await self.stop_ttfb_metrics() + elif isinstance(frame, TTSStoppedFrame): + logger.info("HeyGenVideoService TTS stopped") + await self._send_audio( + b"\x00\x00", + self.SAMPLE_RATE, + self._event_id, + finish=True, + ) + await self.stop_processing_metrics() + self._event_id = None + elif isinstance(frame, StartInterruptionFrame): + await self._interrupt() + elif isinstance(frame, EndFrame | CancelFrame): + logger.info("HeyGenVideoService session ended") + await self._stop() + else: + await self.push_frame(frame, direction) + except Exception as e: + logger.error( + f"Error processing frame: {e}", + exc_info=True, + ) diff --git a/agent_video/pipecat/pipeline/src/functions/pipeline_audio.py b/agent_video/pipecat/pipeline/src/functions/pipeline_audio.py new file mode 100644 index 00000000..b543dd8d --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/pipeline_audio.py @@ -0,0 +1,200 @@ +import os + +from dotenv import load_dotenv +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from 
pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, +) +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.deepgram import DeepgramSTTService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import ( + DailyParams, + DailyTransport, +) +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +load_dotenv(override=True) + + +class PipecatPipelineAudioInput(BaseModel): + agent_name: str + agent_id: str + agent_run_id: str + daily_room_url: str + daily_room_token: str + + +VOICE_IDS = { + "system_1": os.getenv("CARTESIA_VOICE_ID"), # Restack voice + "system_2": os.getenv( + "CARTESIA_VOICE_ID_SYSTEM_2" + ), # Female voice +} + + +def get_agent_backend_host(engine_api_address: str) -> str: + if not engine_api_address: + return "http://localhost:9233" + if not engine_api_address.startswith("https://"): + return "https://" + engine_api_address + return engine_api_address + + +@function.defn(name="pipecat_pipeline_audio") +async def pipecat_pipeline_audio( + function_input: PipecatPipelineAudioInput, +) -> bool: + try: + engine_api_address = os.environ.get( + "RESTACK_ENGINE_API_ADDRESS", + ) + agent_backend_host = get_agent_backend_host( + engine_api_address, + ) + + log.info( + "Using RESTACK_ENGINE_API_ADDRESS", + agent_backend_host=agent_backend_host, + ) + + agent_url = f"{agent_backend_host}/stream/agents/{function_input.agent_name}/{function_input.agent_id}/{function_input.agent_run_id}" + log.info("Agent URL", agent_url=agent_url) + + transport = DailyTransport( + room_url=function_input.daily_room_url, + token=function_input.daily_room_token, + bot_name="bot", + params=DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + transcription_enabled=True, + camera_out_enabled=False, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + stt = 
DeepgramSTTService( + api_key=os.getenv("DEEPGRAM_API_KEY"), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id=VOICE_IDS["system_1"], + ) + + llm = OpenAILLMService( + api_key="pipecat-pipeline", + base_url=agent_url, + ) + + messages = [ + { + "role": "system", + "content": "", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator( + context + ) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, # STT + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output, + context_aggregator.assistant(), # Assistant spoken responses + ], + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + check_dangling_tasks=True, + ) + + @transport.event_handler( + "on_first_participant_joined", + ) + async def on_first_participant_joined( + transport: DailyTransport, + participant: dict, + ) -> None: + log.info( + "First participant joined", + participant=participant, + ) + + messages.append( + { + "role": "system", + "content": "Please introduce yourself to the user.", + } + ) + await task.queue_frames( + [context_aggregator.user().get_context_frame()] + ) + + @transport.event_handler("on_app_message") + async def on_app_message(transport, message, sender): + text = message.get("text") + log.info(f"Received {sender} message with {text}") + try: + tts.set_voice(VOICE_IDS["system_2"]) + await tts.say(f"SYSTEM TWO: {text}") + tts.set_voice(VOICE_IDS["system_1"]) + except Exception as e: + log.error("Error processing message", error=e) + + @transport.event_handler("on_participant_left") + async def on_participant_left( + transport: DailyTransport, + participant: dict, + reason: str, + ) -> None: + log.info( + "Participant left", + participant=participant, + reason=reason, + ) + await 
task.cancel() + + runner = PipelineRunner() + + try: + await runner.run(task) + except Exception as e: + log.error( + "Pipeline runner error, cancelling pipeline", + error=e, + ) + await task.cancel() + raise NonRetryableError( + "Pipeline runner error, cancelling pipeline" + ) from e + + return True + except Exception as e: + error_message = "Pipecat pipeline failed" + log.error(error_message, error=e) + raise NonRetryableError(error_message) from e diff --git a/agent_video/pipecat/pipeline/src/functions/pipeline_heygen.py b/agent_video/pipecat/pipeline/src/functions/pipeline_heygen.py new file mode 100644 index 00000000..a1808c0b --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/pipeline_heygen.py @@ -0,0 +1,269 @@ +import os + +import aiohttp +from dotenv import load_dotenv +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, +) +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.deepgram import DeepgramSTTService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import ( + DailyParams, + DailyTransport, +) +from pydantic import BaseModel +from restack_ai.function import ( + NonRetryableError, + function, + log, +) + +from src.functions.heygen_client import ( + HeyGenClient, + NewSessionRequest, +) +from src.functions.heygen_video_service import HeyGenVideoService + +# from pipecat.frames.frames import EndFrame, TTSSpeakFrame + +load_dotenv(override=True) + + +class PipecatPipelineHeygenInput(BaseModel): + agent_name: str + agent_id: str + agent_run_id: str + daily_room_url: str + daily_room_token: str + + +def get_agent_backend_host(engine_api_address: str) -> str: + if not engine_api_address: + return "http://localhost:9233" + 
if not engine_api_address.startswith("https://"): + return "https://" + engine_api_address + return engine_api_address + + +@function.defn(name="pipecat_pipeline_heygen") +async def pipecat_pipeline_heygen( + function_input: PipecatPipelineHeygenInput, +) -> bool: + try: + async with aiohttp.ClientSession() as session: + engine_api_address = os.environ.get( + "RESTACK_ENGINE_API_ADDRESS", + ) + agent_backend_host = get_agent_backend_host( + engine_api_address, + ) + + log.info( + "Using RESTACK_ENGINE_API_ADDRESS", + agent_backend_host=agent_backend_host, + ) + + agent_url = f"{agent_backend_host}/stream/agents/{function_input.agent_name}/{function_input.agent_id}/{function_input.agent_run_id}" + log.info("Agent URL", agent_url=agent_url) + + transport = DailyTransport( + room_url=function_input.daily_room_url, + token=function_input.daily_room_token, + bot_name="HeyGen", + params=DailyParams( + audio_out_enabled=True, + camera_out_enabled=True, + camera_out_width=854, + camera_out_height=480, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + audio_out_sample_rate=HeyGenVideoService.SAMPLE_RATE, + ), + ) + + stt = DeepgramSTTService( + api_key=os.getenv("DEEPGRAM_API_KEY"), + ) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id=os.getenv("CARTESIA_VOICE_ID"), + sample_rate=HeyGenVideoService.SAMPLE_RATE, + ) + + llm = OpenAILLMService( + api_key="pipecat-pipeline", + base_url=agent_url, + ) + + messages = [ + { + "role": "system", + "content": "", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator( + context, + ) + + heygen_client = HeyGenClient( + api_key=os.getenv("HEYGEN_API_KEY"), + session=session, + ) + + session_response = await heygen_client.new_session( + NewSessionRequest( + avatarName="Bryan_IT_Sitting_public", + version="v2", + ), + ) + + await heygen_client.start_session( + session_response.session_id, + ) + + heygen_video_service = HeyGenVideoService( + 
session_id=session_response.session_id, + session_token=session_response.access_token, + session=session, + realtime_endpoint=session_response.realtime_endpoint, + livekit_room_url=session_response.url, + ) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, # STT + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + heygen_video_service, # HeyGen output layer + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + ], + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + audio_out_sample_rate=HeyGenVideoService.SAMPLE_RATE, + ), + check_dangling_tasks=True, + ) + + @transport.event_handler( + "on_first_participant_joined", + ) + async def on_first_participant_joined( + transport: DailyTransport, + participant: dict, + ) -> None: + log.info( + "First participant joined", + participant=participant, + ) + + messages.append( + { + "role": "system", + "content": "Please introduce yourself to the user. 
Keep it short and concise.", + }, + ) + + await task.queue_frames( + [ + context_aggregator.user().get_context_frame(), + ], + ) + + # @transport.event_handler("on_app_message") + # async def on_app_message(transport, message, sender): + # author = message.get("author") + # text = message.get("text") + + # log.debug(f"Received {sender} message from {author}: {text}") + + # try: + + # await tts.say(f"I received a message from {author}.") + + # await task.queue_frames([ + # TTSSpeakFrame(f"I received a message from {author}."), + # EndFrame(), + # ]) + + # log.info("tts say") + + # await tts.say(text) + + # log.info("llm push frame") + + # await llm.push_frame(TTSSpeakFrame(text)) + + # log.info("task queue frames") + + # await task.queue_frames([ + # TTSSpeakFrame(text), + # EndFrame(), + # ]) + + # log.info("task queue frames context_aggregator") + + # messages.append( + # { + # "role": "user", + # "content": f"Say {text}", + # }, + # ) + # await task.queue_frames( + # [ + # context_aggregator.user().get_context_frame(), + # ], + # ) + + # except Exception as e: + # log.error("Error processing message", error=e) + + @transport.event_handler("on_participant_left") + async def on_participant_left( + transport: DailyTransport, + participant: dict, + reason: str, + ) -> None: + log.info( + "Participant left", + participant=participant, + reason=reason, + ) + await task.cancel() + + runner = PipelineRunner() + + try: + await runner.run(task) + except Exception as e: + log.error( + "Pipeline runner error, cancelling pipeline", + error=e, + ) + await task.cancel() + raise NonRetryableError( + "Pipeline runner error, cancelling pipeline" + ) from e + + return True + except Exception as e: + error_message = "Pipecat pipeline failed" + log.error(error_message, error=e) + raise NonRetryableError(error_message) from e diff --git a/agent_video/pipecat/pipeline/src/functions/pipeline_tavus.py b/agent_video/pipecat/pipeline/src/functions/pipeline_tavus.py new file mode 100644 
import os
from collections.abc import Mapping
from typing import Any

import aiohttp
from dotenv import load_dotenv
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
)
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import (
    DailyParams,
    DailyTransport,
)
from pydantic import BaseModel
from restack_ai.function import NonRetryableError, function, log

from src.functions.tavus_video_service import TavusVideoService

load_dotenv(override=True)


class PipecatPipelineTavusInput(BaseModel):
    # Identifiers used to build the streaming URL of the Restack agent
    # that backs the OpenAI-compatible LLM endpoint.
    agent_name: str
    agent_id: str
    agent_run_id: str
    # Daily room the Tavus replica and the user join.
    daily_room_url: str


def get_agent_backend_host(engine_api_address: str) -> str:
    """Resolve the Restack engine address into a usable base URL.

    Mirrors the identical helper in pipeline_audio.py / pipeline_heygen.py
    (extracted here for consistency): empty -> local default, bare host ->
    prefixed with https://, full https URL -> unchanged.
    """
    if not engine_api_address:
        return "http://localhost:9233"
    if not engine_api_address.startswith("https://"):
        return "https://" + engine_api_address
    return engine_api_address


@function.defn(name="pipecat_pipeline_tavus")
async def pipecat_pipeline_tavus(
    function_input: PipecatPipelineTavusInput,
) -> bool:
    """Run a Pipecat STT -> LLM -> TTS -> Tavus video pipeline in a Daily room.

    The LLM step is an OpenAI-compatible endpoint served by the Restack agent
    identified by the input, so conversation state lives in the agent.

    Returns True when the pipeline finishes; raises NonRetryableError on any
    failure (after cancelling the pipeline task if the runner errored).
    """
    try:
        async with aiohttp.ClientSession() as session:
            engine_api_address = os.environ.get(
                "RESTACK_ENGINE_API_ADDRESS",
            )
            agent_backend_host = get_agent_backend_host(
                engine_api_address,
            )

            log.info(
                "Using RESTACK_ENGINE_API_ADDRESS",
                agent_backend_host=agent_backend_host,
            )

            agent_url = f"{agent_backend_host}/stream/agents/{function_input.agent_name}/{function_input.agent_id}/{function_input.agent_run_id}"
            log.info("Agent URL", agent_url=agent_url)

            tavus = TavusVideoService(
                api_key=os.getenv("TAVUS_API_KEY"),
                replica_id=os.getenv("TAVUS_REPLICA_ID"),
                session=session,
                # The Tavus conversation id is the last path segment of the
                # Daily room URL.
                conversation_id=function_input.daily_room_url.split(
                    "/"
                )[-1],
            )

            persona_name = await tavus.get_persona_name()

            transport = DailyTransport(
                room_url=function_input.daily_room_url,
                token=None,
                bot_name=persona_name,
                params=DailyParams(
                    audio_out_enabled=True,
                    camera_out_enabled=True,
                    vad_enabled=True,
                    vad_analyzer=SileroVADAnalyzer(),
                    vad_audio_passthrough=True,
                    audio_out_sample_rate=TavusVideoService.SAMPLE_RATE,
                ),
            )

            stt = DeepgramSTTService(
                api_key=os.getenv("DEEPGRAM_API_KEY"),
            )

            tts = CartesiaTTSService(
                api_key=os.getenv("CARTESIA_API_KEY"),
                voice_id=os.getenv("CARTESIA_VOICE_ID"),
                sample_rate=TavusVideoService.SAMPLE_RATE,
            )

            # api_key is a placeholder: the agent URL does the real auth/routing.
            llm = OpenAILLMService(
                api_key="pipecat-pipeline",
                base_url=agent_url,
            )

            messages = [
                {
                    "role": "system",
                    "content": "",
                },
            ]

            context = OpenAILLMContext(messages)
            context_aggregator = llm.create_context_aggregator(
                context,
            )

            pipeline = Pipeline(
                [
                    transport.input(),  # Transport user input
                    stt,  # STT
                    context_aggregator.user(),  # User responses
                    llm,  # LLM
                    tts,  # TTS
                    tavus,  # Tavus output layer
                    transport.output(),  # Transport bot output
                    context_aggregator.assistant(),  # Assistant spoken responses
                ],
            )

            task = PipelineTask(
                pipeline,
                params=PipelineParams(
                    allow_interruptions=True,
                    enable_metrics=True,
                    enable_usage_metrics=True,
                    report_only_initial_ttfb=True,
                    audio_out_sample_rate=TavusVideoService.SAMPLE_RATE,
                ),
                check_dangling_tasks=True,
            )

            @transport.event_handler("on_participant_joined")
            async def on_participant_joined(
                transport: DailyTransport,
                participant: Mapping[str, Any],
            ) -> None:
                participant_id = participant.get("id")
                if participant_id is None:
                    log.warning(
                        "Participant joined without an 'id', skipping update_subscriptions.",
                    )
                    return

                # The Tavus replica joins the room too; mute its microphone
                # so the pipeline does not transcribe its own speech.
                if (
                    participant.get("info", {}).get(
                        "userName",
                        "",
                    )
                    == persona_name
                ):
                    log.debug(
                        f"Ignoring {participant_id}'s microphone",
                    )
                    await transport.update_subscriptions(
                        participant_settings={
                            str(participant_id): {
                                "media": {
                                    "microphone": "unsubscribed",
                                },
                            },
                        },
                    )
                else:
                    # A real user joined: prompt the agent to introduce itself.
                    messages.append(
                        {
                            "role": "system",
                            "content": "Please introduce yourself to the user. Keep it short and concise.",
                        },
                    )
                    await task.queue_frames(
                        [
                            context_aggregator.user().get_context_frame(),
                        ],
                    )

            @transport.event_handler("on_participant_left")
            async def on_participant_left(
                transport: DailyTransport,
                participant: dict,
                reason: str,
            ) -> None:
                log.info(
                    "Participant left",
                    participant=participant,
                    reason=reason,
                )
                await task.cancel()

            runner = PipelineRunner()

            try:
                await runner.run(task)
            except Exception as e:
                log.error(
                    "Pipeline runner error, cancelling pipeline",
                    error=e,
                )
                await task.cancel()
                raise NonRetryableError(
                    "Pipeline runner error, cancelling pipeline"
                ) from e

            return True
    except Exception as e:
        error_message = "Pipecat pipeline failed"
        log.error(error_message, error=e)
        raise NonRetryableError(error_message) from e
diff --git a/agent_video/pipecat/pipeline/src/functions/send_agent_event.py b/agent_video/pipecat/pipeline/src/functions/send_agent_event.py new file mode 100644 index 00000000..33edcc34 --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/send_agent_event.py @@ -0,0 +1,31 @@ +from typing import Any + +from pydantic import BaseModel +from restack_ai.function import NonRetryableError, function + +from src.client import client + + +class SendAgentEventInput(BaseModel): + event_name: str + agent_id: str + run_id: str | None = None + event_input: dict[str, Any] | None = None + + +@function.defn() +async def send_agent_event( + function_input: SendAgentEventInput, +) -> str: + try: + return await client.send_agent_event( + event_name=function_input.event_name, + agent_id=function_input.agent_id, + run_id=function_input.run_id, + event_input=function_input.event_input, + ) + + except Exception as e: + raise NonRetryableError( + f"send_agent_event failed: {e}" + ) from e diff --git a/agent_video/pipecat/pipeline/src/functions/tavus_video_service.py b/agent_video/pipecat/pipeline/src/functions/tavus_video_service.py new file mode 100644 index 00000000..8489daad --- /dev/null +++ b/agent_video/pipecat/pipeline/src/functions/tavus_video_service.py @@ -0,0 +1,166 @@ +import base64 + +import aiohttp +from loguru import logger +from pipecat.audio.utils import create_default_resampler +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + StartInterruptionFrame, + TransportMessageUrgentFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import AIService + + +class TavusVideoService(AIService): + """Class to send base64 encoded audio to Tavus""" + + def __init__( + self, + *, + api_key: str, + replica_id: str, + persona_id: str = "pipecat0", # Use `pipecat0` so that your TTS voice is used in place of the Tavus persona + conversation_id: 
str | None = None, + session: aiohttp.ClientSession, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._api_key = api_key + self._replica_id = replica_id + self._persona_id = persona_id + self._session = session + self._conversation_id = conversation_id + + self._resampler = create_default_resampler() + + # Constants + SAMPLE_RATE = 24000 + + async def initialize(self) -> str: + url = "https://tavusapi.com/v2/conversations" + headers = { + "Content-Type": "application/json", + "x-api-key": self._api_key, + } + payload = { + "replica_id": self._replica_id, + "persona_id": self._persona_id, + } + async with self._session.post( + url, headers=headers, json=payload + ) as r: + r.raise_for_status() + response_json = await r.json() + + logger.debug( + f"TavusVideoService joined {response_json['conversation_url']}" + ) + self._conversation_id = response_json["conversation_id"] + return response_json["conversation_url"] + + def can_generate_metrics(self) -> bool: + return True + + async def get_persona_name(self) -> str: + url = ( + f"https://tavusapi.com/v2/personas/{self._persona_id}" + ) + headers = { + "Content-Type": "application/json", + "x-api-key": self._api_key, + } + async with self._session.get(url, headers=headers) as r: + r.raise_for_status() + response_json = await r.json() + + logger.debug( + f"TavusVideoService persona grabbed {response_json}" + ) + return response_json["persona_name"] + + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._end_conversation() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._end_conversation() + + async def _end_conversation(self) -> None: + url = f"https://tavusapi.com/v2/conversations/{self._conversation_id}/end" + headers = { + "Content-Type": "application/json", + "x-api-key": self._api_key, + } + async with self._session.post(url, headers=headers) as r: + r.raise_for_status() + + async def _encode_audio_and_send( + self, audio: bytes, 
in_rate: int, done: bool + ) -> None: + """Encodes audio to base64 and sends it to Tavus""" + if not done: + audio = await self._resampler.resample( + audio, in_rate, self.SAMPLE_RATE + ) + audio_base64 = base64.b64encode(audio).decode("utf-8") + logger.trace(f"{self}: sending {len(audio)} bytes") + await self._send_audio_message(audio_base64, done=done) + + async def process_frame( + self, frame: Frame, direction: FrameDirection + ): + await super().process_frame(frame, direction) + if isinstance(frame, TTSStartedFrame): + await self.start_processing_metrics() + await self.start_ttfb_metrics() + self._current_idx_str = str(frame.id) + elif isinstance(frame, TTSAudioRawFrame): + await self._encode_audio_and_send( + frame.audio, frame.sample_rate, done=False + ) + elif isinstance(frame, TTSStoppedFrame): + await self._encode_audio_and_send( + b"\x00", self.SAMPLE_RATE, done=True + ) + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + elif isinstance(frame, StartInterruptionFrame): + await self._send_interrupt_message() + else: + await self.push_frame(frame, direction) + + async def _send_interrupt_message(self) -> None: + transport_frame = TransportMessageUrgentFrame( + message={ + "message_type": "conversation", + "event_type": "conversation.interrupt", + "conversation_id": self._conversation_id, + } + ) + await self.push_frame(transport_frame) + + async def _send_audio_message( + self, audio_base64: str, done: bool + ) -> None: + transport_frame = TransportMessageUrgentFrame( + message={ + "message_type": "conversation", + "event_type": "conversation.echo", + "conversation_id": self._conversation_id, + "properties": { + "modality": "audio", + "inference_id": self._current_idx_str, + "audio": audio_base64, + "done": done, + "sample_rate": self.SAMPLE_RATE, + }, + } + ) + await self.push_frame(transport_frame) diff --git a/agent_video/pipecat/pipeline/src/services.py b/agent_video/pipecat/pipeline/src/services.py new file mode 100644 index 
00000000..18624d01 --- /dev/null +++ b/agent_video/pipecat/pipeline/src/services.py @@ -0,0 +1,55 @@ +import asyncio +import logging +import webbrowser +from pathlib import Path + +from restack_ai.restack import ServiceOptions +from watchfiles import run_process + +from src.client import client +from src.functions.daily_delete_room import daily_delete_room +from src.functions.pipeline_audio import pipecat_pipeline_audio +from src.functions.pipeline_heygen import pipecat_pipeline_heygen +from src.functions.pipeline_tavus import pipecat_pipeline_tavus +from src.functions.send_agent_event import send_agent_event +from src.workflows.pipeline import PipelineWorkflow + + +async def main() -> None: + await client.start_service( + task_queue="pipeline", + workflows=[PipelineWorkflow], + functions=[ + pipecat_pipeline_tavus, + pipecat_pipeline_heygen, + pipecat_pipeline_audio, + daily_delete_room, + send_agent_event, + ], + options=ServiceOptions( + endpoint_group="agent_video", # used to locally show both agent and pipeline endpoints in UI + ), + ) + + +def run_services() -> None: + try: + asyncio.run(main()) + except KeyboardInterrupt: + logging.info( + "Service interrupted by user. 
Exiting gracefully.", + ) + + +def watch_services() -> None: + watch_path = Path.cwd() + logging.info( + "Watching %s and its subdirectories for changes...", + watch_path, + ) + webbrowser.open("http://localhost:5233") + run_process(watch_path, recursive=True, target=run_services) + + +if __name__ == "__main__": + run_services() diff --git a/agent_video/pipecat/pipeline/src/workflows/__init__.py b/agent_video/pipecat/pipeline/src/workflows/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agent_video/pipecat/pipeline/src/workflows/pipeline.py b/agent_video/pipecat/pipeline/src/workflows/pipeline.py new file mode 100644 index 00000000..612a009a --- /dev/null +++ b/agent_video/pipecat/pipeline/src/workflows/pipeline.py @@ -0,0 +1,157 @@ +from datetime import timedelta +from typing import Literal + +from pydantic import BaseModel +from restack_ai.workflow import ( + NonRetryableError, + import_functions, + log, + workflow, +) + +with import_functions(): + from src.functions.daily_delete_room import ( + DailyDeleteRoomInput, + daily_delete_room, + ) + from src.functions.pipeline_audio import ( + PipecatPipelineAudioInput, + pipecat_pipeline_audio, + ) + from src.functions.pipeline_heygen import ( + PipecatPipelineHeygenInput, + pipecat_pipeline_heygen, + ) + from src.functions.pipeline_tavus import ( + PipecatPipelineTavusInput, + pipecat_pipeline_tavus, + ) + from src.functions.send_agent_event import ( + SendAgentEventInput, + send_agent_event, + ) + + +class PipelineWorkflowInput(BaseModel): + video_service: Literal["tavus", "heygen", "audio"] + agent_name: str + agent_id: str + agent_run_id: str + daily_room_url: str | None = None + daily_room_token: str | None = None + + +@workflow.defn() +class PipelineWorkflow: + @workflow.run + async def run( + self, workflow_input: PipelineWorkflowInput + ) -> bool: + try: + if workflow_input.video_service == "tavus": + await workflow.step( + task_queue="pipeline", + function=pipecat_pipeline_tavus, + 
function_input=PipecatPipelineTavusInput( + agent_name=workflow_input.agent_name, + agent_id=workflow_input.agent_id, + agent_run_id=workflow_input.agent_run_id, + daily_room_url=workflow_input.daily_room_url, + ), + start_to_close_timeout=timedelta(minutes=20), + ) + + elif workflow_input.video_service == "heygen": + try: + await workflow.step( + task_queue="pipeline", + function=pipecat_pipeline_heygen, + function_input=PipecatPipelineHeygenInput( + agent_name=workflow_input.agent_name, + agent_id=workflow_input.agent_id, + agent_run_id=workflow_input.agent_run_id, + daily_room_url=workflow_input.daily_room_url, + daily_room_token=workflow_input.daily_room_token, + ), + start_to_close_timeout=timedelta( + minutes=20 + ), + ) + + except Exception as e: + log.error("Error heygen pipeline", error=e) + await workflow.step( + task_queue="pipeline", + function=daily_delete_room, + function_input=DailyDeleteRoomInput( + room_name=workflow_input.agent_run_id, + ), + ) + + await workflow.step( + task_queue="pipeline", + function=daily_delete_room, + function_input=DailyDeleteRoomInput( + room_name=workflow_input.agent_run_id, + ), + ) + + await workflow.step( + task_queue="pipeline", + function=send_agent_event, + function_input=SendAgentEventInput( + event_name="end", + agent_id=workflow_input.agent_id, + run_id=workflow_input.agent_run_id, + ), + ) + + elif workflow_input.video_service == "audio": + try: + await workflow.step( + task_queue="pipeline", + function=pipecat_pipeline_audio, + function_input=PipecatPipelineAudioInput( + agent_name=workflow_input.agent_name, + agent_id=workflow_input.agent_id, + agent_run_id=workflow_input.agent_run_id, + daily_room_url=workflow_input.daily_room_url, + daily_room_token=workflow_input.daily_room_token, + ), + ) + + except Exception as e: + log.error("Error audio pipeline", error=e) + await workflow.step( + task_queue="pipeline", + function=daily_delete_room, + function_input=DailyDeleteRoomInput( + 
room_name=workflow_input.agent_run_id, + ), + ) + + await workflow.step( + task_queue="pipeline", + function=daily_delete_room, + function_input=DailyDeleteRoomInput( + room_name=workflow_input.agent_run_id, + ), + ) + + await workflow.step( + task_queue="pipeline", + function=send_agent_event, + function_input=SendAgentEventInput( + event_name="end", + agent_id=workflow_input.agent_id, + run_id=workflow_input.agent_run_id, + ), + ) + + except Exception as e: + error_message = f"Error during pipecat_pipeline: {e}" + raise NonRetryableError(error_message) from e + else: + log.info("Pipecat pipeline done") + + return True diff --git a/agent_video/room_url.png b/agent_video/pipecat/room_url.png similarity index 100% rename from agent_video/room_url.png rename to agent_video/pipecat/room_url.png diff --git a/agent_video/tavus_replica.png b/agent_video/pipecat/tavus_replica.png similarity index 100% rename from agent_video/tavus_replica.png rename to agent_video/pipecat/tavus_replica.png diff --git a/agent_video/src/agents/agent.py b/agent_video/src/agents/agent.py deleted file mode 100644 index 895f15f3..00000000 --- a/agent_video/src/agents/agent.py +++ /dev/null @@ -1,62 +0,0 @@ -from datetime import timedelta - -from pydantic import BaseModel -from restack_ai.agent import NonRetryableError, agent, import_functions, log - -with import_functions(): - from src.functions.context_docs import context_docs - from src.functions.llm_chat import LlmChatInput, Message, llm_chat - -class MessagesEvent(BaseModel): - messages: list[Message] - - -class EndEvent(BaseModel): - end: bool - - -@agent.defn() -class AgentVideo: - def __init__(self) -> None: - self.end = False - self.messages: list[Message] = [] - - @agent.event - async def messages(self, messages_event: MessagesEvent) -> list[Message]: - log.info(f"Received message: {messages_event.messages}") - self.messages.extend(messages_event.messages) - - try: - assistant_message = await agent.step( - function=llm_chat, - 
function_input=LlmChatInput(messages=self.messages), - start_to_close_timeout=timedelta(seconds=120), - ) - except Exception as e: - error_message = f"llm_chat function failed: {e}" - raise NonRetryableError(error_message) from e - else: - self.messages.append(Message(role="assistant", content=str(assistant_message))) - return self.messages - - @agent.event - async def end(self, end: EndEvent) -> EndEvent: - log.info("Received end") - self.end = True - return end - - @agent.run - async def run(self) -> None: - try: - docs = await agent.step(function=context_docs) - except Exception as e: - error_message = f"context_docs function failed: {e}" - raise NonRetryableError(error_message) from e - else: - system_prompt=f""" - You are an interactive video assistant, your answers will be used in text to speech so try to keep answers short and concise so that interaction is seamless. - You can answer questions about the following documentation: - {docs} - """ - self.messages.append(Message(role="system", content=system_prompt)) - await agent.condition(lambda: self.end) diff --git a/agent_video/src/functions/llm_chat.py b/agent_video/src/functions/llm_chat.py deleted file mode 100644 index ce82e2a4..00000000 --- a/agent_video/src/functions/llm_chat.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -from typing import TYPE_CHECKING, Literal - -from openai import OpenAI -from pydantic import BaseModel, Field -from restack_ai.function import NonRetryableError, function, stream_to_websocket - -from src.client import api_address - -if TYPE_CHECKING: - from openai.resources.chat.completions import ChatCompletionChunk, Stream - - -class Message(BaseModel): - role: Literal["system", "user", "assistant"] - content: str - - -class LlmChatInput(BaseModel): - system_content: str | None = None - model: str | None = None - messages: list[Message] = Field(default_factory=list) - stream: bool = True - - -@function.defn() -async def llm_chat(function_input: LlmChatInput) -> str: - try: - client = 
OpenAI( - base_url="https://ai.restack.io", api_key=os.environ.get("RESTACK_API_KEY") - ) - - if function_input.system_content: - # Insert the system message at the beginning - function_input.messages.insert( - 0, Message(role="system", content=function_input.system_content) - ) - - # Convert Message objects to dictionaries - messages_dicts = [message.model_dump() for message in function_input.messages] - # Get the streamed response from OpenAI API - response: Stream[ChatCompletionChunk] = client.chat.completions.create( - model=function_input.model or "gpt-4o-mini", - messages=messages_dicts, - stream=True, - ) - - return await stream_to_websocket(api_address=api_address, data=response) - - except Exception as e: - error_message = f"llm_chat function failed: {e}" - raise NonRetryableError(error_message) from e diff --git a/agent_video/src/functions/pipeline.py b/agent_video/src/functions/pipeline.py deleted file mode 100644 index 42972fdd..00000000 --- a/agent_video/src/functions/pipeline.py +++ /dev/null @@ -1,172 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# -import asyncio -import os -from collections.abc import Mapping -from typing import Any - -import aiohttp -from dotenv import load_dotenv -from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.cartesia import CartesiaTTSService -from pipecat.services.deepgram import DeepgramSTTService -from pipecat.services.openai import OpenAILLMService -from pipecat.services.tavus import TavusVideoService -from pipecat.transports.services.daily import DailyParams, DailyTransport -from pydantic import BaseModel -from restack_ai.function import NonRetryableError, function, log - -load_dotenv(override=True) - 
- -class PipecatPipelineInput(BaseModel): - agent_name: str - agent_id: str - agent_run_id: str - - -@function.defn(name="pipecat_pipeline") -async def pipecat_pipeline(function_input: PipecatPipelineInput) -> str: - try: - async with aiohttp.ClientSession() as session: - engine_api_address = os.environ.get("RESTACK_ENGINE_API_ADDRESS") - if not engine_api_address: - agent_backend_host = "http://localhost:9233" - elif not engine_api_address.startswith("https://"): - agent_backend_host = "https://" + engine_api_address - else: - agent_backend_host = engine_api_address - - log.info( - "Using RESTACK_ENGINE_API_ADDRESS", - agent_backend_host=agent_backend_host, - ) - - agent_url = f"{agent_backend_host}/stream/agents/{function_input.agent_name}/{function_input.agent_id}/{function_input.agent_run_id}" - log.info("Agent URL", agent_url=agent_url) - - tavus = TavusVideoService( - api_key=os.getenv("TAVUS_API_KEY"), - replica_id=os.getenv("TAVUS_REPLICA_ID"), - session=session, - ) - - persona_name = await tavus.get_persona_name() - room_url = await tavus.initialize() - - transport = DailyTransport( - room_url=room_url, - token=None, - bot_name=persona_name, - params=DailyParams( - vad_enabled=True, - vad_analyzer=SileroVADAnalyzer(), - vad_audio_passthrough=True, - ), - ) - - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - - tts = CartesiaTTSService( - api_key=os.getenv("CARTESIA_API_KEY"), - voice_id=os.getenv("CARTESIA_VOICE_ID"), - ) - - llm = OpenAILLMService(api_key="pipecat-pipeline", base_url=agent_url) - - messages = [ - { - "role": "system", - "content": ( - "You are Philippe Brule, CTO and co-founder of restack : the entreprise-grade backend framework to help develoeprs build reliable and accurate AI agents. You live in San Francisco, California." 
- ), - }, - ] - - context = OpenAILLMContext(messages) - context_aggregator = llm.create_context_aggregator(context) - - pipeline = Pipeline( - [ - transport.input(), # Transport user input - stt, # STT - context_aggregator.user(), # User responses - llm, # LLM - tts, # TTS - tavus, # Tavus output layer - transport.output(), # Transport bot output - context_aggregator.assistant(), # Assistant spoken responses - ] - ) - - task = PipelineTask( - pipeline, - params=PipelineParams( - audio_in_sample_rate=16000, - audio_out_sample_rate=16000, - allow_interruptions=True, - enable_metrics=True, - enable_usage_metrics=True, - report_only_initial_ttfb=True, - ), - ) - - @transport.event_handler("on_participant_joined") - async def on_participant_joined( - transport: DailyTransport, participant: Mapping[str, Any] - ) -> None: - # Ignore the Tavus replica's microphone - if participant.get("info", {}).get("userName", "") == persona_name: - log.debug(f"Ignoring {participant['id']}'s microphone") - await transport.update_subscriptions( - participant_settings={ - participant["id"]: { - "media": {"microphone": "unsubscribed"}, - } - } - ) - else: - messages.append( - { - "role": "system", - "content": "Please introduce yourself to the user. Keep it short and concise.", - } - ) - await task.queue_frames( - [context_aggregator.user().get_context_frame()] - ) - - @transport.event_handler("on_participant_left") - async def on_participant_left(transport, participant, reason): - await task.cancel() - - runner = PipelineRunner() - - async def run_pipeline() -> None: - try: - await runner.run(task) - except Exception as e: - error_message = "Pipeline runner encountered an error, cancelling pipeline" - log.error(error_message, error=e) - # Cancel the pipeline task if an error occurs within the pipeline runner. - await task.cancel() - raise NonRetryableError(error_message) from e - - # Launch the pipeline runner as a background task so it doesn't block the return. 
- asyncio.create_task(run_pipeline()) - - log.info("Pipecat pipeline started", room_url=room_url) - - # Return the room_url immediately. - return room_url - except Exception as e: - error_message = "Pipecat pipeline failed" - log.error(error_message, error=e) - raise NonRetryableError(error_message) from e diff --git a/agent_video/src/services.py b/agent_video/src/services.py deleted file mode 100644 index c42a0c6f..00000000 --- a/agent_video/src/services.py +++ /dev/null @@ -1,43 +0,0 @@ -import asyncio -import logging -import webbrowser -from pathlib import Path - -from watchfiles import run_process - -from src.agents.agent import AgentVideo -from src.client import client -from src.functions.context_docs import context_docs -from src.functions.llm_chat import llm_chat -from src.functions.pipeline import pipecat_pipeline -from src.workflows.room import RoomWorkflow - - -async def main() -> None: - await client.start_service( - agents=[AgentVideo], - workflows=[RoomWorkflow], - functions=[ - llm_chat, - pipecat_pipeline, - context_docs, - ], - ) - - -def run_services() -> None: - try: - asyncio.run(main()) - except KeyboardInterrupt: - logging.info("Service interrupted by user. 
Exiting gracefully.") - - -def watch_services() -> None: - watch_path = Path.cwd() - logging.info("Watching %s and its subdirectories for changes...", watch_path) - webbrowser.open("http://localhost:5233") - run_process(watch_path, recursive=True, target=run_services) - - -if __name__ == "__main__": - run_services() diff --git a/agent_video/src/workflows/room.py b/agent_video/src/workflows/room.py deleted file mode 100644 index f2825de6..00000000 --- a/agent_video/src/workflows/room.py +++ /dev/null @@ -1,59 +0,0 @@ -from datetime import timedelta - -from pydantic import BaseModel -from restack_ai.workflow import ( - NonRetryableError, - ParentClosePolicy, - import_functions, - log, - workflow, - workflow_info, -) - -from src.agents.agent import AgentVideo - -with import_functions(): - from src.functions.pipeline import PipecatPipelineInput, pipecat_pipeline - - -class RoomWorkflowOutput(BaseModel): - room_url: str - - -@workflow.defn() -class RoomWorkflow: - @workflow.run - async def run(self) -> RoomWorkflowOutput: - agent_id = f"{workflow_info().workflow_id}-agent" - try: - agent = await workflow.child_start( - agent=AgentVideo, - agent_id=agent_id, - start_to_close_timeout=timedelta(minutes=20), - parent_close_policy=ParentClosePolicy.ABANDON, - ) - except Exception as e: - error_message = f"Error during child_start: {e}" - raise NonRetryableError(error_message) from e - else: - log.info("Agent started", agent=agent) - - try: - room_url = await workflow.step( - function=pipecat_pipeline, - function_input=PipecatPipelineInput( - agent_name=AgentVideo.__name__, - agent_id=agent.id, - agent_run_id=agent.run_id, - ), - start_to_close_timeout=timedelta(minutes=20), - ) - except Exception as e: - error_message = f"Error during pipecat_pipeline: {e}" - raise NonRetryableError(error_message) from e - else: - log.info("Pipecat pipeline started") - - log.info("RoomWorkflow completed", room_url=room_url) - - return RoomWorkflowOutput(room_url=room_url)