From 44a349386c135df45beb07c32b82f6e17747c2c7 Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Fri, 27 Sep 2024 10:43:49 -0400
Subject: [PATCH 1/4] Consolidate update frame classes into a single
 UpdateSettingsFrame class

---
 CHANGELOG.md                        |   3 +
 src/pipecat/frames/frames.py        | 121 +++++-----------------
 src/pipecat/services/ai_services.py |  38 +++++----
 src/pipecat/services/anthropic.py   |  63 +++++++++------
 src/pipecat/services/google.py      |   9 ++-
 src/pipecat/services/openai.py      |  54 ++++++++-----
 src/pipecat/services/together.py    |  62 ++++++++------
 7 files changed, 157 insertions(+), 193 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 37189eb47..0f489556c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -93,6 +93,9 @@ async def on_connected(processor):
 
 ### Changed
 
+- Consolidated the individual update settings frame classes into a single
+  `UpdateSettingsFrame` class each for STT, LLM, and TTS.
+
 - We now distinguish between input and output audio and image frames. We
   introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
   and `OutputImageRawFrame` (and other subclasses of those). The input frames
diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py
index 273aad214..1b31b9c88 100644
--- a/src/pipecat/frames/frames.py
+++ b/src/pipecat/frames/frames.py
@@ -4,9 +4,8 @@
 # SPDX-License-Identifier: BSD 2-Clause License
 #
 
-from typing import Any, List, Optional, Tuple
-
 from dataclasses import dataclass, field
+from typing import Any, List, Optional, Tuple
 
 from pipecat.clocks.base_clock import BaseClock
 from pipecat.metrics.metrics import MetricsData
@@ -528,113 +527,35 @@ def __str__(self):
 
 
 @dataclass
-class LLMModelUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM model."""
-
-    model: str
-
-
-@dataclass
-class LLMTemperatureUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM temperature."""
-
-    temperature: float
-
-
-@dataclass
-class LLMTopKUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM top_k."""
-
-    top_k: int
-
-
-@dataclass
-class LLMTopPUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM top_p."""
-
-    top_p: float
-
-
-@dataclass
-class LLMFrequencyPenaltyUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM frequency
-    penalty.
-
-    """
-
-    frequency_penalty: float
-
-
-@dataclass
-class LLMPresencePenaltyUpdateFrame(ControlFrame):
-    """A control frame containing a request to update to a new LLM presence
-    penalty.
+class LLMUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update LLM settings.""" - """ - - presence_penalty: float + model: Optional[str] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + max_tokens: Optional[int] = None + seed: Optional[int] = None + extra: dict = field(default_factory=dict) @dataclass -class LLMMaxTokensUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM max tokens.""" - - max_tokens: int - - -@dataclass -class LLMSeedUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM seed.""" - - seed: int - - -@dataclass -class LLMExtraUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM extra params.""" - - extra: dict - - -@dataclass -class TTSModelUpdateFrame(ControlFrame): - """A control frame containing a request to update the TTS model.""" - - model: str - - -@dataclass -class TTSVoiceUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new TTS voice.""" - - voice: str - - -@dataclass -class TTSLanguageUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new TTS language and - optional voice. - - """ - - language: Language - - -@dataclass -class STTModelUpdateFrame(ControlFrame): - """A control frame containing a request to update the STT model and optional - language. - - """ +class TTSUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update TTS settings.""" - model: str + model: Optional[str] = None + voice: Optional[str] = None + language: Optional[Language] = None @dataclass -class STTLanguageUpdateFrame(ControlFrame): - """A control frame containing a request to update to STT language.""" +class STTUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update STT settings.""" - language: Language + model: Optional[str] = None + language: Optional[Language] = None @dataclass diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 16280b024..79e52531d 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -7,10 +7,11 @@ import asyncio import io import wave - from abc import abstractmethod from typing import AsyncGenerator, List, Optional, Tuple +from loguru import logger + from pipecat.frames.frames import ( AudioRawFrame, CancelFrame, @@ -18,31 +19,26 @@ ErrorFrame, Frame, LLMFullResponseEndFrame, - STTLanguageUpdateFrame, - STTModelUpdateFrame, StartFrame, StartInterruptionFrame, + STTUpdateSettingsFrame, + TextFrame, TTSAudioRawFrame, - TTSLanguageUpdateFrame, - TTSModelUpdateFrame, TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, - TTSVoiceUpdateFrame, - TextFrame, + TTSUpdateSettingsFrame, UserImageRequestFrame, VisionImageRawFrame, ) from pipecat.metrics.metrics import MetricsData +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.transcriptions.language import Language from pipecat.utils.audio import calculate_audio_volume from pipecat.utils.string import match_endofsentence from pipecat.utils.time import seconds_to_nanoseconds from pipecat.utils.utils import exp_smoothing -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - -from loguru 
import logger class AIService(FrameProcessor): @@ -230,12 +226,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) elif isinstance(frame, TTSSpeakFrame): await self._push_tts_frames(frame.text) - elif isinstance(frame, TTSModelUpdateFrame): - await self.set_model(frame.model) - elif isinstance(frame, TTSVoiceUpdateFrame): - await self.set_voice(frame.voice) - elif isinstance(frame, TTSLanguageUpdateFrame): - await self.set_language(frame.language) + elif isinstance(frame, TTSUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.voice is not None: + await self.set_voice(frame.voice) + if frame.language is not None: + await self.set_language(frame.language) else: await self.push_frame(frame, direction) @@ -408,10 +405,11 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # In this service we accumulate audio internally and at the end we # push a TextFrame. We don't really want to push audio frames down. await self.process_audio_frame(frame) - elif isinstance(frame, STTModelUpdateFrame): - await self.set_model(frame.model) - elif isinstance(frame, STTLanguageUpdateFrame): - await self.set_language(frame.language) + elif isinstance(frame, STTUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.language is not None: + await self.set_language(frame.language) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 8b8e187ea..1c4cd284e 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -5,47 +5,47 @@ # import base64 -import json -import io import copy -from typing import Any, Dict, List, Optional +import io +import json +import re +from asyncio import CancelledError from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from loguru import logger from PIL import Image -from asyncio import CancelledError -import re from pydantic import BaseModel, Field from pipecat.frames.frames import ( Frame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, LLMEnablePromptCachingFrame, - LLMModelUpdateFrame, - TextFrame, - VisionImageRawFrame, - UserImageRequestFrame, - UserImageRawFrame, - LLMMessagesFrame, - LLMFullResponseStartFrame, LLMFullResponseEndFrame, - FunctionCallResultFrame, - FunctionCallInProgressFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMUpdateSettingsFrame, StartInterruptionFrame, + TextFrame, + UserImageRawFrame, + UserImageRequestFrame, + VisionImageRawFrame, ) from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import LLMService +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) -from pipecat.processors.aggregators.llm_response import ( - LLMUserContextAggregator, - LLMAssistantContextAggregator, -) - -from loguru import logger +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService try: - from anthropic import AsyncAnthropic, NOT_GIVEN, NotGiven + from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -293,9 +293,20 @@ async def 
process_frame(self, frame: Frame, direction: FrameDirection): # UserImageRawFrames coming through the pipeline and add them # to the context. context = AnthropicLLMContext.from_image_frame(frame) - elif isinstance(frame, LLMModelUpdateFrame): - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) + elif isinstance(frame, LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.max_tokens is not None: + await self.set_max_tokens(frame.max_tokens) + if frame.temperature is not None: + await self.set_temperature(frame.temperature) + if frame.top_k is not None: + await self.set_top_k(frame.top_k) + if frame.top_p is not None: + await self.set_top_p(frame.top_p) + if frame.extra: + await self.set_extra(frame.extra) elif isinstance(frame, LLMEnablePromptCachingFrame): logger.debug(f"Setting enable prompt caching to: [{frame.enable}]") self._enable_prompt_caching_beta = frame.enable diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index 38af3e41f..53efd8c17 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -17,7 +17,7 @@ LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMModelUpdateFrame, + LLMUpdateSettingsFrame, TextFrame, TTSAudioRawFrame, TTSStartedFrame, @@ -136,9 +136,10 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): context = OpenAILLMContext.from_messages(frame.messages) elif isinstance(frame, VisionImageRawFrame): context = OpenAILLMContext.from_image_frame(frame) - elif isinstance(frame, LLMModelUpdateFrame): - logger.debug(f"Switching LLM model to: [{frame.model}]") - self._create_client(frame.model) + elif isinstance(frame, LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 47bee5ec1..a830b65a8 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -4,38 +4,39 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import aiohttp import base64 import io import json -import httpx - from dataclasses import dataclass - from typing import Any, AsyncGenerator, Dict, List, Literal, Optional + +import aiohttp +import httpx +from loguru import logger +from PIL import Image from pydantic import BaseModel, Field from pipecat.frames.frames import ( ErrorFrame, Frame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMModelUpdateFrame, + LLMUpdateSettingsFrame, + StartInterruptionFrame, + TextFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TextFrame, URLImageRawFrame, VisionImageRawFrame, - FunctionCallResultFrame, - FunctionCallInProgressFrame, - StartInterruptionFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_response import ( - LLMUserContextAggregator, LLMAssistantContextAggregator, + LLMUserContextAggregator, ) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, @@ -44,12 +45,14 @@ from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import ImageGenService, LLMService, TTSService -from PIL import Image - -from loguru import logger - try: - from openai import AsyncOpenAI, 
AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN + from openai import ( + NOT_GIVEN, + AsyncOpenAI, + AsyncStream, + BadRequestError, + DefaultAsyncHttpxClient, + ) from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -280,9 +283,22 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): context = OpenAILLMContext.from_messages(frame.messages) elif isinstance(frame, VisionImageRawFrame): context = OpenAILLMContext.from_image_frame(frame) - elif isinstance(frame, LLMModelUpdateFrame): - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) + elif isinstance(frame, LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.frequency_penalty is not None: + await self.set_frequency_penalty(frame.frequency_penalty) + if frame.presence_penalty is not None: + await self.set_presence_penalty(frame.presence_penalty) + if frame.seed is not None: + await self.set_seed(frame.seed) + if frame.temperature is not None: + await self.set_temperature(frame.temperature) + if frame.top_p is not None: + await self.set_top_p(frame.top_p) + if frame.extra: + await self.set_extra(frame.extra) else: await self.push_frame(frame, direction) @@ -464,7 +480,7 @@ async def process_frame(self, frame, direction): await self._push_aggregation() else: logger.warning( - f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id" + "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id" ) self._function_call_in_progress = None self._function_call_result = None diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py index b1365bc69..981aa6de2 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together.py @@ -7,37 +7,36 @@ import json import re import uuid -from pydantic import BaseModel, Field - -from typing import Any, Dict, List, Optional -from dataclasses import dataclass from asyncio import CancelledError +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from loguru import logger +from pydantic import BaseModel, Field from pipecat.frames.frames import ( Frame, - LLMModelUpdateFrame, - TextFrame, - UserImageRequestFrame, - LLMMessagesFrame, - LLMFullResponseStartFrame, - LLMFullResponseEndFrame, - FunctionCallResultFrame, FunctionCallInProgressFrame, + FunctionCallResultFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMUpdateSettingsFrame, StartInterruptionFrame, + TextFrame, + UserImageRequestFrame, ) from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import LLMService +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) -from pipecat.processors.aggregators.llm_response import ( - LLMUserContextAggregator, - LLMAssistantContextAggregator, -) - -from loguru import logger +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService try: from together import AsyncTogether @@ -188,7 +187,7 @@ async def _process_context(self, context: 
OpenAILLMContext):
         if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
             await self._extract_function_call(context, function_call_accumulator)
 
-        except CancelledError as e:
+        except CancelledError:
             # todo: implement token counting estimates for use when the user interrupts a long generation
             # we do this in the anthropic.py service
             raise
@@ -206,9 +205,24 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             context = frame.context
         elif isinstance(frame, LLMMessagesFrame):
             context = TogetherLLMContext.from_messages(frame.messages)
-        elif isinstance(frame, LLMModelUpdateFrame):
-            logger.debug(f"Switching LLM model to: [{frame.model}]")
-            self.set_model_name(frame.model)
+        elif isinstance(frame, LLMUpdateSettingsFrame):
+            if frame.model is not None:
+                logger.debug(f"Switching LLM model to: [{frame.model}]")
+                self.set_model_name(frame.model)
+            if frame.frequency_penalty is not None:
+                await self.set_frequency_penalty(frame.frequency_penalty)
+            if frame.max_tokens is not None:
+                await self.set_max_tokens(frame.max_tokens)
+            if frame.presence_penalty is not None:
+                await self.set_presence_penalty(frame.presence_penalty)
+            if frame.temperature is not None:
+                await self.set_temperature(frame.temperature)
+            if frame.top_k is not None:
+                await self.set_top_k(frame.top_k)
+            if frame.top_p is not None:
+                await self.set_top_p(frame.top_p)
+            if frame.extra:
+                await self.set_extra(frame.extra)
         else:
             await self.push_frame(frame, direction)
 
@@ -338,7 +352,7 @@ async def process_frame(self, frame, direction):
                     await self._push_aggregation()
                 else:
                     logger.warning(
-                        f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
+                        "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
                     )
                 self._function_call_in_progress = None
                 self._function_call_result = None

From 7fe118ce639aeac6bd84c7c2e721f3cf6dd298e0 Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Fri, 27 Sep 2024 11:22:03 -0400
Subject: [PATCH 2/4] Align use of language param across TTS services

---
 src/pipecat/services/azure.py      | 17 ++++++++++-------
 src/pipecat/services/elevenlabs.py | 10 +++++-----
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py
index c8fa095ab..a1349cefe 100644
--- a/src/pipecat/services/azure.py
+++ b/src/pipecat/services/azure.py
@@ -41,7 +41,10 @@
         SpeechRecognizer,
         SpeechSynthesizer,
     )
-    from azure.cognitiveservices.speech.audio import AudioStreamFormat, PushAudioInputStream
+    from azure.cognitiveservices.speech.audio import (
+        AudioStreamFormat,
+        PushAudioInputStream,
+    )
     from azure.cognitiveservices.speech.dialog import AudioConfig
     from openai import AsyncAzureOpenAI
 except ModuleNotFoundError as e:
@@ -73,7 +76,7 @@ def create_client(self, api_key=None, base_url=None, **kwargs):
 class AzureTTSService(TTSService):
     class InputParams(BaseModel):
         emphasis: Optional[str] = None
-        language_code: Optional[str] = "en-US"
+        language: Optional[str] = "en-US"
         pitch: Optional[str] = None
         rate: Optional[str] = "1.05"
         role: Optional[str] = None
@@ -105,7 +108,7 @@ def can_generate_metrics(self) -> bool:
 
     def _construct_ssml(self, text: str) -> str:
         ssml = (
-            f"<speak version='1.0' xml:lang='{self._params.language_code}' "
+            f"<speak version='1.0' xml:lang='{self._params.language}' "
             "xmlns='http://www.w3.org/2001/10/synthesis' "
             "xmlns:mstts='http://www.w3.org/2001/mstts'>"
             f"<voice name='{self._voice_id}'>"
@@ -155,9 +158,9 @@ async def set_emphasis(self, emphasis: str):
         logger.debug(f"Setting TTS emphasis to: [{emphasis}]")
         self._params.emphasis = emphasis
 
-    async def set_language_code(self, language_code: str):
-        logger.debug(f"Setting TTS language code to: [{language_code}]")
-        
self._params.language_code = language_code + async def set_language(self, language: str): + logger.debug(f"Setting TTS language code to: [{language}]") + self._params.language = language async def set_pitch(self, pitch: str): logger.debug(f"Setting TTS pitch to: [{pitch}]") @@ -187,7 +190,7 @@ async def set_params(self, **kwargs): valid_params = { "voice": self.set_voice, "emphasis": self.set_emphasis, - "language_code": self.set_language_code, + "language_code": self.set_language, "pitch": self.set_pitch, "rate": self.set_rate, "role": self.set_role, diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 79d90bc58..ca4713f5f 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -72,7 +72,7 @@ def calculate_word_times( class ElevenLabsTTSService(AsyncWordTTSService): class InputParams(BaseModel): - language_code: Optional[str] = None + language: Optional[str] = None output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000" optimize_streaming_latency: Optional[str] = None stability: Optional[float] = None @@ -229,13 +229,13 @@ async def _connect(self): if self._params.optimize_streaming_latency: url += f"&optimize_streaming_latency={self._params.optimize_streaming_latency}" - # language_code can only be used with the 'eleven_turbo_v2_5' model - if self._params.language_code: + # language can only be used with the 'eleven_turbo_v2_5' model + if self._params.language: if model == "eleven_turbo_v2_5": - url += f"&language_code={self._params.language_code}" + url += f"&language_code={self._params.language}" else: logger.debug( - f"Language code [{self._params.language_code}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model." + f"Language code [{self._params.language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model." 
) self._websocket = await websockets.connect(url) From d7555609fd752e27ea34879356613d50d4055e8b Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 27 Sep 2024 11:57:50 -0400 Subject: [PATCH 3/4] Add TTS update settings options --- src/pipecat/frames/frames.py | 12 +++++- src/pipecat/services/ai_services.py | 62 ++++++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 1b31b9c88..8059b904b 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -5,7 +5,7 @@ # from dataclasses import dataclass, field -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union from pipecat.clocks.base_clock import BaseClock from pipecat.metrics.metrics import MetricsData @@ -548,6 +548,16 @@ class TTSUpdateSettingsFrame(ControlFrame): model: Optional[str] = None voice: Optional[str] = None language: Optional[Language] = None + speed: Optional[Union[str, float]] = None + emotion: Optional[List[str]] = None + engine: Optional[str] = None + pitch: Optional[str] = None + rate: Optional[str] = None + volume: Optional[str] = None + emphasis: Optional[str] = None + style: Optional[str] = None + style_degree: Optional[str] = None + role: Optional[str] = None @dataclass diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 79e52531d..1cb91d6a2 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -8,7 +8,7 @@ import io import wave from abc import abstractmethod -from typing import AsyncGenerator, List, Optional, Tuple +from typing import AsyncGenerator, List, Optional, Tuple, Union from loguru import logger @@ -170,6 +170,46 @@ async def set_voice(self, voice: str): async def set_language(self, language: Language): pass + @abstractmethod + async def set_speed(self, speed: Union[str, float]): + pass + + @abstractmethod + async def set_emotion(self, emotion: List[str]): + pass + + @abstractmethod + async def set_engine(self, engine: str): + pass + + @abstractmethod + async def set_pitch(self, pitch: str): + pass + + @abstractmethod + async def set_rate(self, rate: str): + pass + + @abstractmethod + async def set_volume(self, volume: str): + pass + + @abstractmethod + async def set_emphasis(self, emphasis: str): + pass + + @abstractmethod + async def set_style(self, style: str): + pass + + @abstractmethod + async def set_style_degree(self, style_degree: str): + pass + + @abstractmethod + async def set_role(self, role: str): + pass + # Converts the text to audio. 
@abstractmethod async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: @@ -233,6 +273,26 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.set_voice(frame.voice) if frame.language is not None: await self.set_language(frame.language) + if frame.speed is not None: + await self.set_speed(frame.speed) + if frame.emotion is not None: + await self.set_emotion(frame.emotion) + if frame.engine is not None: + await self.set_engine(frame.engine) + if frame.pitch is not None: + await self.set_pitch(frame.pitch) + if frame.rate is not None: + await self.set_rate(frame.rate) + if frame.volume is not None: + await self.set_volume(frame.volume) + if frame.emphasis is not None: + await self.set_emphasis(frame.emphasis) + if frame.style is not None: + await self.set_style(frame.style) + if frame.style_degree is not None: + await self.set_style_degree(frame.style_degree) + if frame.role is not None: + await self.set_role(frame.role) else: await self.push_frame(frame, direction) From 1f77863aef18c1f81d181d05e03924cbdf855e2a Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 30 Sep 2024 12:45:32 -0400 Subject: [PATCH 4/4] Code review feedback --- src/pipecat/services/ai_services.py | 66 ++++++++++++++++------------- src/pipecat/services/anthropic.py | 29 +++++++------ src/pipecat/services/openai.py | 33 ++++++++------- src/pipecat/services/together.py | 37 ++++++++-------- 4 files changed, 90 insertions(+), 75 deletions(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 1cb91d6a2..ba78b24f8 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -248,6 +248,34 @@ async def _push_tts_frames(self, text: str): # interrupted, the text is not added to the assistant context. 
await self.push_frame(TextFrame(text)) + async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.voice is not None: + await self.set_voice(frame.voice) + if frame.language is not None: + await self.set_language(frame.language) + if frame.speed is not None: + await self.set_speed(frame.speed) + if frame.emotion is not None: + await self.set_emotion(frame.emotion) + if frame.engine is not None: + await self.set_engine(frame.engine) + if frame.pitch is not None: + await self.set_pitch(frame.pitch) + if frame.rate is not None: + await self.set_rate(frame.rate) + if frame.volume is not None: + await self.set_volume(frame.volume) + if frame.emphasis is not None: + await self.set_emphasis(frame.emphasis) + if frame.style is not None: + await self.set_style(frame.style) + if frame.style_degree is not None: + await self.set_style_degree(frame.style_degree) + if frame.role is not None: + await self.set_role(frame.role) + async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -267,32 +295,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif isinstance(frame, TTSSpeakFrame): await self._push_tts_frames(frame.text) elif isinstance(frame, TTSUpdateSettingsFrame): - if frame.model is not None: - await self.set_model(frame.model) - if frame.voice is not None: - await self.set_voice(frame.voice) - if frame.language is not None: - await self.set_language(frame.language) - if frame.speed is not None: - await self.set_speed(frame.speed) - if frame.emotion is not None: - await self.set_emotion(frame.emotion) - if frame.engine is not None: - await self.set_engine(frame.engine) - if frame.pitch is not None: - await self.set_pitch(frame.pitch) - if frame.rate is not None: - await self.set_rate(frame.rate) - if frame.volume is not None: - await self.set_volume(frame.volume) - if frame.emphasis is not None: - await self.set_emphasis(frame.emphasis) - if frame.style is not None: - await self.set_style(frame.style) - if frame.style_degree is not None: - await self.set_style_degree(frame.style_degree) - if frame.role is not None: - await self.set_role(frame.role) + await self._update_tts_settings(frame) else: await self.push_frame(frame, direction) @@ -454,6 +457,12 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Returns transcript as a string""" pass + async def _update_stt_settings(self, frame: STTUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.language is not None: + await self.set_language(frame.language) + async def process_audio_frame(self, frame: AudioRawFrame): await self.process_generator(self.run_stt(frame.audio)) @@ -466,10 +475,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # push a TextFrame. We don't really want to push audio frames down. 
await self.process_audio_frame(frame) elif isinstance(frame, STTUpdateSettingsFrame): - if frame.model is not None: - await self.set_model(frame.model) - if frame.language is not None: - await self.set_language(frame.language) + await self._update_stt_settings(frame) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 1c4cd284e..bc91e4e16 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -279,6 +279,21 @@ async def _process_context(self, context: OpenAILLMContext): cache_read_input_tokens=cache_read_input_tokens, ) + async def _update_settings(self, frame: LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.max_tokens is not None: + await self.set_max_tokens(frame.max_tokens) + if frame.temperature is not None: + await self.set_temperature(frame.temperature) + if frame.top_k is not None: + await self.set_top_k(frame.top_k) + if frame.top_p is not None: + await self.set_top_p(frame.top_p) + if frame.extra: + await self.set_extra(frame.extra) + async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -294,19 +309,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # to the context. context = AnthropicLLMContext.from_image_frame(frame) elif isinstance(frame, LLMUpdateSettingsFrame): - if frame.model is not None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.max_tokens is not None: - await self.set_max_tokens(frame.max_tokens) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_k is not None: - await self.set_top_k(frame.top_k) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + await self._update_settings(frame) elif isinstance(frame, LLMEnablePromptCachingFrame): logger.debug(f"Setting enable prompt caching to: [{frame.enable}]") self._enable_prompt_caching_beta = frame.enable diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index a830b65a8..f0892b9ca 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -273,6 +273,23 @@ async def _handle_function_call(self, context, tool_call_id, function_name, argu arguments=arguments, ) + async def _update_settings(self, frame: LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.frequency_penalty is not None: + await self.set_frequency_penalty(frame.frequency_penalty) + if frame.presence_penalty is not None: + await self.set_presence_penalty(frame.presence_penalty) + if frame.seed is not None: + await self.set_seed(frame.seed) + if frame.temperature is not None: + await self.set_temperature(frame.temperature) + if frame.top_p is not None: + await self.set_top_p(frame.top_p) + if frame.extra: + await self.set_extra(frame.extra) + async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -284,21 +301,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif isinstance(frame, VisionImageRawFrame): context = OpenAILLMContext.from_image_frame(frame) elif isinstance(frame, LLMUpdateSettingsFrame): - if frame.model is not 
None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.frequency_penalty is not None: - await self.set_frequency_penalty(frame.frequency_penalty) - if frame.presence_penalty is not None: - await self.set_presence_penalty(frame.presence_penalty) - if frame.seed is not None: - await self.set_seed(frame.seed) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + await self._update_settings(frame) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py index 981aa6de2..e4068ecfc 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together.py @@ -128,6 +128,25 @@ async def set_extra(self, extra: Dict[str, Any]): logger.debug(f"Switching LLM extra to: [{extra}]") self._extra = extra + async def _update_settings(self, frame: LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.frequency_penalty is not None: + await self.set_frequency_penalty(frame.frequency_penalty) + if frame.max_tokens is not None: + await self.set_max_tokens(frame.max_tokens) + if frame.presence_penalty is not None: + await self.set_presence_penalty(frame.presence_penalty) + if frame.temperature is not None: + await self.set_temperature(frame.temperature) + if frame.top_k is not None: + await self.set_top_k(frame.top_k) + if frame.top_p is not None: + await self.set_top_p(frame.top_p) + if frame.extra: + await self.set_extra(frame.extra) + async def _process_context(self, context: OpenAILLMContext): try: await self.push_frame(LLMFullResponseStartFrame()) @@ -206,23 +225,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif isinstance(frame, LLMMessagesFrame): context = TogetherLLMContext.from_messages(frame.messages) elif isinstance(frame, LLMUpdateSettingsFrame): - if frame.model is not None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.frequency_penalty is not None: - await self.set_frequency_penalty(frame.frequency_penalty) - if frame.max_tokens is not None: - await self.set_max_tokens(frame.max_tokens) - if frame.presence_penalty is not None: - await self.set_presence_penalty(frame.presence_penalty) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_k is not None: - await self.set_top_k(frame.top_k) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + await self._update_settings(frame) else: await self.push_frame(frame, direction)
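
---

Reviewer note (illustrative, not part of the patches): a minimal sketch of how the
consolidated frames are meant to be used from application code. It assumes an
already-built `PipelineTask` named `task`, as in the repo's foundational examples;
the TTS voice ID is a placeholder.

```python
from pipecat.frames.frames import (
    LLMUpdateSettingsFrame,
    STTUpdateSettingsFrame,
    TTSUpdateSettingsFrame,
)
from pipecat.transcriptions.language import Language


async def switch_conversation_to_spanish(task):
    # Each frame carries only the settings being changed. The services apply
    # the non-None fields through their existing setters (set_model, set_voice,
    # set_temperature, ...) and leave every other setting untouched.
    await task.queue_frames(
        [
            STTUpdateSettingsFrame(language=Language.ES),
            TTSUpdateSettingsFrame(voice="your-voice-id", language=Language.ES),
            LLMUpdateSettingsFrame(temperature=0.7, max_tokens=512),
        ]
    )
```

Because every field is Optional and defaults to None (or an empty dict for
`extra`), one frame per service type replaces the fourteen single-purpose update
frames removed in PATCH 1/4 while still making partial updates cheap to express.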