From 45d19d598d10d9863755a8b7d1a435160135c3b0 Mon Sep 17 00:00:00 2001 From: surajagarwal Date: Mon, 2 Sep 2024 23:16:46 +0530 Subject: [PATCH] Adding Transport capability for Twilio, along with the right acknowledgement of BOTSpeakingFrame. Currently full data is being pushed to Twilio. it sends back the BOTStoppedSpeaking Frame, which should not be the case, as bot was still speaking and handled by Twilio. --- src/pipecat/frames/frames.py | 8 +- src/pipecat/serializers/twilio.py | 74 ++++++++--- src/pipecat/transports/services/twilio.py | 146 ++++++++++++++++++++++ src/pipecat/utils/audio.py | 2 + 4 files changed, 211 insertions(+), 19 deletions(-) create mode 100644 src/pipecat/transports/services/twilio.py diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 13c2f53f..99e97ca5 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Any, List, Mapping, Optional, Tuple +from typing import Any, List, Mapping, Optional, Tuple, Literal from dataclasses import dataclass, field @@ -500,3 +500,9 @@ class VADParamsUpdateFrame(ControlFrame): to be pushed upstream from RTVI processor. """ params: VADParams + +@dataclass +class MarkFrame(SystemFrame): + passed_name: str + seq_number: int | None = None + type: Literal["request", "response"] = "request" diff --git a/src/pipecat/serializers/twilio.py b/src/pipecat/serializers/twilio.py index 8836fcd6..0a2a705b 100644 --- a/src/pipecat/serializers/twilio.py +++ b/src/pipecat/serializers/twilio.py @@ -9,9 +9,10 @@ from pydantic import BaseModel -from pipecat.frames.frames import AudioRawFrame, Frame +from pipecat.frames.frames import AudioRawFrame, Frame, StartInterruptionFrame, MarkFrame from pipecat.serializers.base_serializer import FrameSerializer -from pipecat.utils.audio import ulaw_to_pcm, pcm_to_ulaw +from pipecat.utils.audio import change_audio_frame_rate, ulaw_to_pcm, pcm_to_ulaw +from loguru import logger class TwilioFrameSerializer(FrameSerializer): @@ -28,29 +29,59 @@ def __init__(self, stream_sid: str, params: InputParams = InputParams()): self._params = params def serialize(self, frame: Frame) -> str | bytes | None: - if not isinstance(frame, AudioRawFrame): - return None + if isinstance(frame, AudioRawFrame): + data = frame.audio - data = frame.audio + if frame.encoding == "mulaw": + if frame.sample_rate != self._params.twilio_sample_rate: + serialized_data = change_audio_frame_rate( + data, frame.sample_rate, self._params.twilio_sample_rate + ) + else: + serialized_data = data + elif frame.encoding == "pcm": + serialized_data = pcm_to_ulaw( + data, frame.sample_rate, self._params.twilio_sample_rate) + else: + raise ValueError(f"Unsupported encoding: {frame.encoding}") - serialized_data = pcm_to_ulaw(data, frame.sample_rate, self._params.twilio_sample_rate) - payload = base64.b64encode(serialized_data).decode("utf-8") - answer = { - "event": "media", - "streamSid": self._stream_sid, - "media": { - "payload": payload + payload = base64.b64encode(serialized_data).decode("utf-8") + answer = { + "event": "media", + "streamSid": self._stream_sid, + "media": { + "payload": payload + } } - } - return json.dumps(answer) + return json.dumps(answer) + + if isinstance(frame, StartInterruptionFrame): + answer = {"event": "clear", "streamSid": self._stream_sid} + return json.dumps(answer) + + if isinstance(frame, MarkFrame) and frame.type == "request": + answer = { + "event": "mark", + "streamSid": self._stream_sid, + "mark": {"name": frame.passed_name}, + # "sequenceNumber": frame.seq_number, + } + logger.info(f"Sending the mark frame with data: {answer}") + return json.dumps(answer) def deserialize(self, data: str | bytes) -> Frame | None: message = json.loads(data) - if message["event"] != "media": - return None - else: + if message["event"] == "mark": + mark_frame = MarkFrame( + seq_number=message["sequenceNumber"], + type="response", + passed_name=message["mark"]["name"], + ) + return mark_frame + + elif message["event"] == "media": payload_base64 = message["media"]["payload"] payload = base64.b64decode(payload_base64) @@ -61,5 +92,12 @@ def deserialize(self, data: str | bytes) -> Frame | None: audio_frame = AudioRawFrame( audio=deserialized_data, num_channels=1, - sample_rate=self._params.sample_rate) + sample_rate=self._params.sample_rate, + encoding="pcm" + ) return audio_frame + else: + return None + + + diff --git a/src/pipecat/transports/services/twilio.py b/src/pipecat/transports/services/twilio.py new file mode 100644 index 00000000..abefde8a --- /dev/null +++ b/src/pipecat/transports/services/twilio.py @@ -0,0 +1,146 @@ +import asyncio + +from pipecat.frames.frames import ( + Frame, + AudioRawFrame, + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + MarkFrame +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.transports.network.fastapi_websocket import ( + FastAPIWebsocketOutputTransport, + FastAPIWebsocketParams, + FastAPIWebsocketInputTransport, + FastAPIWebsocketCallbacks, +) +from pipecat.transports.base_transport import BaseTransport +from starlette.websockets import WebSocket, WebSocketState +from loguru import logger + + +class TwilioOutputTransport(FastAPIWebsocketOutputTransport): + def __init__(self, websocket: WebSocket, params: FastAPIWebsocketParams, **kwargs): + super().__init__(websocket, params, **kwargs) + self.current_count = 0 + self.received_count = 0 + + async def _bot_started_speaking(self): + logger.debug( + f"Bot started speaking, Bot already speaking: {self._bot_speaking}" + ) + if not self._bot_speaking: + self._bot_speaking = True + await self._internal_push_frame( + BotStartedSpeakingFrame(), FrameDirection.UPSTREAM + ) + + async def _bot_stopped_speaking(self): + logger.info("Pushing the Marker at the end of stream") + self.current_count += 1 + mark_frame = MarkFrame( + passed_name=str(self.current_count), + type="request", + seq_number=self.current_count, + ) + payload = self._params.serializer.serialize(mark_frame) + if payload and self._websocket.client_state == WebSocketState.CONNECTED: + await self._websocket.send_text(payload) + logger.debug(f"Pushed the mark frame {payload}") + + async def _send_bot_stopped_speaking(self): + logger.info("Bot Stopped") + await self._internal_push_frame( + BotStoppedSpeakingFrame(), FrameDirection.UPSTREAM + ) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + logger.trace(f"Received the frame: {frame}, {direction}") + if ( + self._bot_speaking + and isinstance(frame, MarkFrame) + and frame.type == "response" + ): + logger.info( + "Received the mark frame and sending the Signal for Bot Being stopped from speaking." + ) + self.received_count += 1 + if self.received_count == self.current_count: + logger.info("Bot Stopped Speaking") + await self._send_bot_stopped_speaking() + self._bot_speaking = False + else: + await super().process_frame(frame, direction) + + +class TwilioInputTransport(FastAPIWebsocketInputTransport): + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams, + callbacks: FastAPIWebsocketCallbacks, + **kwargs, + ): + super().__init__(websocket, params, callbacks, **kwargs) + + async def _receive_messages(self): + async for message in self._websocket.iter_text(): + frame = self._params.serializer.deserialize(message) + + if not frame: + continue + + if isinstance(frame, MarkFrame): + logger.info( + f"Pushing the {frame} downstream from CustomTwilioInputProcessor" + ) + await self._internal_push_frame(frame, FrameDirection.DOWNSTREAM) + + if isinstance(frame, AudioRawFrame): + # logger.info(f"Pushing the audio frame {frame}") + await self.push_audio_frame(frame) + + await self._callbacks.on_client_disconnected(self._websocket) + + +class TwilioTransport(BaseTransport): + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams, + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + ): + super().__init__(input_name=input_name, output_name=output_name, loop=loop) + self._params = params + + self._callbacks = FastAPIWebsocketCallbacks( + on_client_connected=self._on_client_connected, + on_client_disconnected=self._on_client_disconnected, + ) + + self._input = TwilioInputTransport( + websocket, self._params, self._callbacks, name=self._input_name + ) + + self._output = TwilioOutputTransport( + websocket, self._params, name=self._output_name + ) + + # Register supported handlers. The user will only be able to register + # these handlers. + self._register_event_handler("on_client_connected") + self._register_event_handler("on_client_disconnected") + + def input(self) -> FrameProcessor: + return self._input + + def output(self) -> FrameProcessor: + return self._output + + async def _on_client_connected(self, websocket): + await self._call_event_handler("on_client_connected", websocket) + + async def _on_client_disconnected(self, websocket): + await self._call_event_handler("on_client_disconnected", websocket) diff --git a/src/pipecat/utils/audio.py b/src/pipecat/utils/audio.py index 0764c6ab..2d375b4e 100644 --- a/src/pipecat/utils/audio.py +++ b/src/pipecat/utils/audio.py @@ -43,6 +43,8 @@ def ulaw_to_pcm(ulaw_bytes: bytes, in_sample_rate: int, out_sample_rate: int): return out_pcm_bytes +def change_audio_frame_rate(audio: bytes, in_sample_rate: int, out_sample_rate: int) -> bytes: + return audioop.ratecv(audio, 2, 1, in_sample_rate, out_sample_rate, None)[0] def pcm_to_ulaw(pcm_bytes: bytes, in_sample_rate: int, out_sample_rate: int): # Resample