diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py index f6c5e609..a7da6362 100644 --- a/agents-core/vision_agents/core/agents/agents.py +++ b/agents-core/vision_agents/core/agents/agents.py @@ -177,6 +177,7 @@ async def simple_response( """ Overwrite simple_response if you want to change how the Agent class calls the LLM """ + logger.info("asking LLM to reply to %s", text) with self.tracer.start_as_current_span("simple_response") as span: response = await self.llm.simple_response( text=text, processors=self.processors, participant=participant diff --git a/agents-core/vision_agents/core/stt/__init__.py b/agents-core/vision_agents/core/stt/__init__.py index 3d16ab45..fcc04dd4 100644 --- a/agents-core/vision_agents/core/stt/__init__.py +++ b/agents-core/vision_agents/core/stt/__init__.py @@ -1,3 +1,4 @@ from .stt import STT +from .events import TranscriptResponse -__all__ = ["STT"] +__all__ = ["STT", "TranscriptResponse"] diff --git a/agents-core/vision_agents/core/stt/events.py b/agents-core/vision_agents/core/stt/events.py index 8f016960..a10210b1 100644 --- a/agents-core/vision_agents/core/stt/events.py +++ b/agents-core/vision_agents/core/stt/events.py @@ -4,22 +4,47 @@ @dataclass -class STTTranscriptEvent(PluginBaseEvent): - """Event emitted when a complete transcript is available.""" - - type: str = field(default='plugin.stt_transcript', init=False) - text: str = "" +class TranscriptResponse: confidence: Optional[float] = None language: Optional[str] = None processing_time_ms: Optional[float] = None audio_duration_ms: Optional[float] = None model_name: Optional[str] = None - words: Optional[list[dict[str, Any]]] = None + other: Optional[dict] = None + +@dataclass +class STTTranscriptEvent(PluginBaseEvent): + """Event emitted when a complete transcript is available.""" + + type: str = field(default='plugin.stt_transcript', init=False) + text: str = "" + response: TranscriptResponse = field(default_factory=TranscriptResponse) is_final: bool = True def __post_init__(self): if not self.text: raise ValueError("Transcript text cannot be empty") + + # Convenience properties for backward compatibility + @property + def confidence(self) -> Optional[float]: + return self.response.confidence + + @property + def language(self) -> Optional[str]: + return self.response.language + + @property + def processing_time_ms(self) -> Optional[float]: + return self.response.processing_time_ms + + @property + def audio_duration_ms(self) -> Optional[float]: + return self.response.audio_duration_ms + + @property + def model_name(self) -> Optional[str]: + return self.response.model_name @dataclass @@ -28,13 +53,29 @@ class STTPartialTranscriptEvent(PluginBaseEvent): type: str = field(default='plugin.stt_partial_transcript', init=False) text: str = "" - confidence: Optional[float] = None - language: Optional[str] = None - processing_time_ms: Optional[float] = None - audio_duration_ms: Optional[float] = None - model_name: Optional[str] = None - words: Optional[list[dict[str, Any]]] = None + response: TranscriptResponse = field(default_factory=TranscriptResponse) is_final: bool = False + + # Convenience properties for backward compatibility + @property + def confidence(self) -> Optional[float]: + return self.response.confidence + + @property + def language(self) -> Optional[str]: + return self.response.language + + @property + def processing_time_ms(self) -> Optional[float]: + return self.response.processing_time_ms + + @property + def 
audio_duration_ms(self) -> Optional[float]: + return self.response.audio_duration_ms + + @property + def model_name(self) -> Optional[str]: + return self.response.model_name @dataclass diff --git a/agents-core/vision_agents/core/stt/stt.py b/agents-core/vision_agents/core/stt/stt.py index 18f10714..c3b21af7 100644 --- a/agents-core/vision_agents/core/stt/stt.py +++ b/agents-core/vision_agents/core/stt/stt.py @@ -1,17 +1,13 @@ import abc import logging -import time import uuid -from typing import Optional, Dict, Any, Tuple, List, Union +from typing import Optional, Dict, Any, Union from getstream.video.rtc.track_util import PcmData from ..edge.types import Participant -from vision_agents.core.events import ( - PluginInitializedEvent, - PluginClosedEvent, -) from vision_agents.core.events.manager import EventManager from . import events +from .events import TranscriptResponse logger = logging.getLogger(__name__) @@ -20,270 +16,94 @@ class STT(abc.ABC): """ Abstract base class for Speech-to-Text implementations. - This class provides a standardized interface for STT implementations with consistent - event emission patterns and error handling. + Subclasses implement this and have to call + - _emit_partial_transcript_event + - _emit_transcript_event + - _emit_error_event for temporary errors - Events: - - transcript: Emitted when a complete transcript is available. - Args: text (str), user_metadata (dict), metadata (dict) - - partial_transcript: Emitted when a partial transcript is available. - Args: text (str), user_metadata (dict), metadata (dict) - - error: Emitted when an error occurs during transcription. - Args: error (Exception) - - Standard Error Handling: - - All implementations should catch exceptions in _process_audio_impl and emit error events - - Use _emit_error_event() helper for consistent error emission - - Log errors with appropriate context using the logger - - Standard Event Emission: - - Use _emit_transcript_event() and _emit_partial_transcript_event() helpers - - Include processing time and audio duration in metadata when available - - Maintain consistent metadata structure across implementations + process_audio is currently called every 20ms. The integration with turn keeping could be improved """ + closed: bool = False def __init__( self, - sample_rate: int = 16000, - *, provider_name: Optional[str] = None, ): - """ - Initialize the STT service. - - Args: - sample_rate: The sample rate of the audio to process, in Hz. - provider_name: Name of the STT provider (e.g., "deepgram", "moonshine") - """ - - self._track = None - self.sample_rate = sample_rate - self._is_closed = False self.session_id = str(uuid.uuid4()) self.provider_name = provider_name or self.__class__.__name__ + self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) - self.events.send(PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="STT", - provider=self.provider_name, - configuration={"sample_rate": sample_rate}, - )) - - def _validate_pcm_data(self, pcm_data: PcmData) -> bool: - """ - Validate PCM data input for processing. - - Args: - pcm_data: The PCM audio data to validate. - - Returns: - True if the data is valid, False otherwise. 
- """ - - if not hasattr(pcm_data, "samples") or pcm_data.samples is None: - logger.warning("PCM data has no samples") - return False - - if not hasattr(pcm_data, "sample_rate") or pcm_data.sample_rate <= 0: - logger.warning("PCM data has invalid sample rate") - return False - - # Check if samples are empty - if hasattr(pcm_data.samples, "__len__") and len(pcm_data.samples) == 0: - logger.debug("Received empty audio samples") - return False - - return True - def _emit_transcript_event( self, text: str, - user_metadata: Optional[Union[Dict[str, Any], Participant]], - metadata: Dict[str, Any], + participant: Optional[Union[Dict[str, Any], Participant]], + response: TranscriptResponse, ): """ Emit a final transcript event with structured data. Args: text: The transcribed text. - user_metadata: User-specific metadata. - metadata: Transcription metadata (processing time, confidence, etc.). + participant: Participant metadata. + response: Transcription response metadata. """ self.events.send(events.STTTranscriptEvent( session_id=self.session_id, plugin_name=self.provider_name, text=text, - user_metadata=user_metadata, - confidence=metadata.get("confidence"), - language=metadata.get("language"), - processing_time_ms=metadata.get("processing_time_ms"), - audio_duration_ms=metadata.get("audio_duration_ms"), - model_name=metadata.get("model_name"), - words=metadata.get("words"), + user_metadata=participant, + response=response, )) def _emit_partial_transcript_event( self, text: str, - user_metadata: Optional[Union[Dict[str, Any], Participant]], - metadata: Dict[str, Any], + participant: Optional[Union[Dict[str, Any], Participant]], + response: TranscriptResponse, ): """ Emit a partial transcript event with structured data. Args: text: The partial transcribed text. - user_metadata: User-specific metadata. - metadata: Transcription metadata (processing time, confidence, etc.). + participant: Participant metadata. + response: Transcription response metadata. """ self.events.send(events.STTPartialTranscriptEvent( session_id=self.session_id, plugin_name=self.provider_name, text=text, - user_metadata=user_metadata, - confidence=metadata.get("confidence"), - language=metadata.get("language"), - processing_time_ms=metadata.get("processing_time_ms"), - audio_duration_ms=metadata.get("audio_duration_ms"), - model_name=metadata.get("model_name"), - words=metadata.get("words"), + user_metadata=participant, + response=response, )) def _emit_error_event( self, error: Exception, context: str = "", - user_metadata: Optional[Union[Dict[str, Any], Participant]] = None, + participant: Optional[Union[Dict[str, Any], Participant]] = None, ): """ - Emit an error event with structured data. - - Args: - error: The exception that occurred. - context: Additional context about where the error occurred. - user_metadata: User-specific metadata. + Emit an error event. Note this should only be emitted for temporary errors. 
+ Permanent errors due to config etc should be directly raised """ self.events.send(events.STTErrorEvent( session_id=self.session_id, plugin_name=self.provider_name, error=error, context=context, - user_metadata=user_metadata, + user_metadata=participant, error_code=getattr(error, "error_code", None), is_recoverable=not isinstance(error, (SystemExit, KeyboardInterrupt)), )) + @abc.abstractmethod async def process_audio( - self, pcm_data: PcmData, participant: Optional[Participant] = None + self, pcm_data: PcmData, participant: Optional[Participant] = None, ): - """ - Process audio data for transcription and emit appropriate events. - - Args: - pcm_data: The PCM audio data to process. - user_metadata: Additional metadata about the user or session. - """ - if self._is_closed: - logger.debug("Ignoring audio processing request - STT is closed") - return - - # Validate input data - if not self._validate_pcm_data(pcm_data): - logger.warning("Invalid PCM data received, skipping processing") - return - - try: - # Process the audio data using the implementation-specific method - audio_duration_ms = ( - pcm_data.duration * 1000 if hasattr(pcm_data, "duration") else None - ) - logger.debug( - "Processing audio chunk", - extra={ - "duration_ms": audio_duration_ms, - "has_user_metadata": participant is not None, - }, - ) - - start_time = time.time() - results = await self._process_audio_impl(pcm_data, participant) - processing_time = time.time() - start_time - - # If no results were returned, just return - if not results: - logger.debug( - "No speech detected in audio", - extra={ - "processing_time_ms": processing_time * 1000, - "audio_duration_ms": audio_duration_ms, - }, - ) - return - - # Process each result and emit the appropriate event - for is_final, text, metadata in results: - # Ensure metadata includes processing time if not already present - if "processing_time_ms" not in metadata: - metadata["processing_time_ms"] = processing_time * 1000 - - if is_final: - self._emit_transcript_event(text, participant, metadata) - else: - self._emit_partial_transcript_event(text, participant, metadata) - - except Exception as e: - # Emit any errors that occur during processing - self._emit_error_event(e, "audio processing", participant) - - @abc.abstractmethod - async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], Participant]] = None - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: - """ - Implementation-specific method to process audio data. - - This method must be implemented by all STT providers and should handle the core - transcription logic. The base class handles event emission and error handling. - - Args: - pcm_data: The PCM audio data to process. Guaranteed to be valid by base class. - user_metadata: Additional metadata about the user or session. - - Returns: - optional list[tuple[bool, str, dict]] | None - • synchronous providers: a list of results. - • asynchronous providers: None (they emit events themselves). - - Notes: - Implementations must not both emit events and return non-empty results, - or duplicate events will be produced. - Exceptions should bubble up; process_audio() will catch them - and emit a single "error" event. - """ pass - @abc.abstractmethod async def close(self): - """ - Close the STT service and release any resources. 
- - Implementations should: - - Set self._is_closed = True - - Clean up any background tasks or connections - - Release any allocated resources - - Log the closure appropriately - """ - if not self._is_closed: - self._is_closed = True - - # Emit closure event - self.events.send(PluginClosedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="STT", - provider=self.provider_name, - cleanup_successful=True, - )) + self.closed = True diff --git a/conftest.py b/conftest.py index e3688d97..da17b102 100644 --- a/conftest.py +++ b/conftest.py @@ -16,9 +16,9 @@ from vision_agents.core.edge.types import PcmData from vision_agents.core.stt.events import STTTranscriptEvent, STTErrorEvent - load_dotenv() + class STTSession: """Helper class for testing STT implementations. @@ -65,6 +65,14 @@ async def wait_for_result(self, timeout: float = 30.0): # Wait for an event await asyncio.wait_for(self._event.wait(), timeout=timeout) + + def get_full_transcript(self) -> str: + """Get full transcription text from all transcript events. + + Returns: + Combined text from all transcripts + """ + return " ".join(t.text for t in self.transcripts) def get_assets_dir(): diff --git a/docs/ai/instructions/ai-stt.md b/docs/ai/instructions/ai-stt.md index e6361a85..2dc7dc83 100644 --- a/docs/ai/instructions/ai-stt.md +++ b/docs/ai/instructions/ai-stt.md @@ -2,22 +2,42 @@ ```python from vision_agents.core import stt +from vision_agents.core.stt.events import TranscriptResponse class MySTT(stt.STT): def __init__( self, api_key: Optional[str] = None, - sample_rate: int = 48000, - client: Optional[AsyncDeepgramClient] = None, + client: Optional[MyClient] = None, ): - super().__init__(sample_rate=sample_rate) + super().__init__(provider_name="my_stt") + # be sure to allow the passing of the client object + # if client is not passed, create one + # pass the most common settings for the client in the init (like api key) - async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], Participant]] = None - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: - pass + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional[Participant] = None, + ): + parts = self.client.stt(pcm_data, stream=True) + full_text = "" + for part in parts: + response = TranscriptResponse( + confidence=0.9, + language='en', + processing_time_ms=300, + audio_duration_ms=2000, + other={} + ) + # parts that aren't finished + self._emit_partial_transcript_event(part, participant, response) + full_text += part + + # the full text + self._emit_transcript_event(full_text, participant, response) ``` diff --git a/examples/01_simple_agent_example/simple_agent_example.py b/examples/01_simple_agent_example/simple_agent_example.py index 6e65e382..9be6e8c3 100644 --- a/examples/01_simple_agent_example/simple_agent_example.py +++ b/examples/01_simple_agent_example/simple_agent_example.py @@ -33,8 +33,7 @@ async def start_agent() -> None: # Create a call call = agent.edge.client.video.call("default", str(uuid4())) - # Open the demo UI - await agent.edge.open_demo(call) + # Have the agent join the call/room with await agent.join(call): @@ -54,6 +53,8 @@ async def start_agent() -> None: # run till the call ends # await agent.say("Hello, how are you?") # await asyncio.sleep(5) + # Open the demo UI + await agent.edge.open_demo(call) await agent.simple_response("tell me something interesting in a short sentence") await agent.finish() diff --git 
a/plugins/deepgram/tests/test_deepgram_stt.py b/plugins/deepgram/tests/test_deepgram_stt.py new file mode 100644 index 00000000..2b6376d4 --- /dev/null +++ b/plugins/deepgram/tests/test_deepgram_stt.py @@ -0,0 +1,34 @@ +import pytest + +from vision_agents.plugins import deepgram +from conftest import STTSession + + +class TestDeepgramSTT: + """Integration tests for Deepgram STT""" + + @pytest.fixture + async def stt(self): + """Create and manage Deepgram STT lifecycle""" + stt = deepgram.STT() + try: + yield stt + finally: + await stt.close() + + @pytest.mark.integration + async def test_transcribe_mia_audio_48khz(self, stt, mia_audio_48khz): + # Create session to collect transcripts and errors + session = STTSession(stt) + + # Process the audio + await stt.process_audio(mia_audio_48khz) + + # Wait for result + await session.wait_for_result(timeout=30.0) + assert not session.errors + + # Verify transcript + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() + diff --git a/plugins/deepgram/tests/test_stt.py b/plugins/deepgram/tests/test_stt.py deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/deepgram/vision_agents/plugins/deepgram/stt.py b/plugins/deepgram/vision_agents/plugins/deepgram/stt.py index 01a3d32d..28873ade 100644 --- a/plugins/deepgram/vision_agents/plugins/deepgram/stt.py +++ b/plugins/deepgram/vision_agents/plugins/deepgram/stt.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional import numpy as np import websockets @@ -20,11 +20,11 @@ from getstream.video.rtc.track_util import PcmData from vision_agents.core import stt +from vision_agents.core.stt import TranscriptResponse from .utils import generate_silence -if TYPE_CHECKING: - from vision_agents.core.edge.types import Participant +from vision_agents.core.edge.types import Participant logger = logging.getLogger(__name__) @@ -50,7 +50,6 @@ def __init__( self, api_key: Optional[str] = None, options: Optional[dict] = None, - sample_rate: int = 48000, language: str = "en-US", interim_results: bool = True, client: Optional[AsyncDeepgramClient] = None, @@ -70,7 +69,7 @@ def __init__( connection_timeout: Time to wait for the Deepgram connection to be established. """ - super().__init__(sample_rate=sample_rate) + super().__init__(provider_name="deepgram") # If no API key was provided, check for DEEPGRAM_API_KEY in environment if api_key is None: @@ -86,12 +85,13 @@ def __init__( client if client is not None else AsyncDeepgramClient(api_key=api_key) ) self.dg_connection: Optional[AsyncV1SocketClient] = None + self.sample_rate = 48000 self.options = options or { "model": "nova-2", "language": language, "encoding": "linear16", - "sample_rate": sample_rate, + "sample_rate": self.sample_rate, "channels": 1, "interim_results": interim_results, } @@ -101,7 +101,7 @@ def __init__( # Generate a silence audio to use as keep-alive message self._keep_alive_data = generate_silence( - sample_rate=sample_rate, duration_ms=10 + sample_rate=self.sample_rate, duration_ms=10 ) self._keep_alive_interval = keep_alive_interval @@ -121,7 +121,7 @@ async def start(self): """ Start the main task establishing the Deepgram connection and processing the events. 
""" - if self._is_closed: + if self.closed: logger.warning("Cannot setup connection - Deepgram instance is closed") return None @@ -178,15 +178,8 @@ async def started(self): ) async def close(self): + await super().close() """Close the Deepgram connection and clean up resources.""" - if self._is_closed: - logger.debug("Deepgram STT service already closed") - return - - logger.info("Closing Deepgram STT service") - self._is_closed = True - - # Close the Deepgram connection if it exists if self.dg_connection: logger.debug("Closing Deepgram connection") try: @@ -225,20 +218,17 @@ async def _on_message( # Check if this is a final result is_final = transcript.get("is_final", False) - # Create metadata with useful information - metadata = { - "confidence": alternatives[0].get("confidence", 0), - "words": alternatives[0].get("words", []), - "is_final": is_final, - "channel_index": transcript.get("channel_index", 0), - } + # Create response metadata + response_metadata = TranscriptResponse( + confidence=alternatives[0].get("confidence", 0), + ) # Emit immediately for real-time responsiveness if is_final: - self._emit_transcript_event(transcript_text, self._current_user, metadata) + self._emit_transcript_event(transcript_text, self._current_user, response_metadata) else: self._emit_partial_transcript_event( - transcript_text, self._current_user, metadata + transcript_text, self._current_user, response_metadata ) logger.debug( @@ -246,7 +236,7 @@ async def _on_message( extra={ "is_final": is_final, "text_length": len(transcript_text), - "confidence": metadata["confidence"], + "confidence": response_metadata.confidence, }, ) @@ -261,29 +251,15 @@ async def _on_connection_close(self, message: Any): logger.warning(f"Deepgram connection closed. message={message}") await self.close() - async def _process_audio_impl( + async def process_audio( self, pcm_data: PcmData, - user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None, - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: - """ - Process audio data through Deepgram for transcription. - - Args: - pcm_data: The PCM audio data to process. - user_metadata: Additional metadata about the user or session. - - Returns: - None - Deepgram operates in asynchronous mode and emits events directly - when transcripts arrive from the streaming service. - """ - if self._is_closed: + participant: Optional[Participant] = None, + ): + if self.closed: logger.warning("Deepgram connection is closed, ignoring audio") return None - # Store the current user context for transcript events - self._current_user = user_metadata # type: ignore[assignment] - # Check if the input sample rate matches the expected sample rate if pcm_data.sample_rate != self.sample_rate: logger.warning( @@ -334,7 +310,7 @@ async def _keepalive_loop(self): Send the silence audio every `interval` seconds to prevent Deepgram from closing the connection. 
""" - while not self._is_closed and self.dg_connection is not None: + while not self.closed and self.dg_connection is not None: if self._last_sent_at + self._keep_alive_interval <= time.time(): logger.debug("Sending keepalive packet to Deepgram...") # Send audio silence to keep the connection open diff --git a/plugins/fish/tests/test_fish_stt.py b/plugins/fish/tests/test_fish_stt.py index 69deff99..8eb092ff 100644 --- a/plugins/fish/tests/test_fish_stt.py +++ b/plugins/fish/tests/test_fish_stt.py @@ -32,9 +32,8 @@ async def test_transcribe_mia_audio(self, stt, mia_audio_16khz): assert not session.errors # Verify transcript - assert len(session.transcripts) > 0, "Expected at least one transcript" - transcript_event = session.transcripts[0] - assert "forgotten treasures" in transcript_event.text.lower() + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() @pytest.mark.integration async def test_transcribe_mia_audio_48khz(self, stt, mia_audio_48khz): @@ -49,6 +48,5 @@ async def test_transcribe_mia_audio_48khz(self, stt, mia_audio_48khz): assert not session.errors # Verify transcript - assert len(session.transcripts) > 0, "Expected at least one transcript" - transcript_event = session.transcripts[0] - assert "forgotten treasures" in transcript_event.text.lower() + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() diff --git a/plugins/fish/vision_agents/plugins/fish/stt.py b/plugins/fish/vision_agents/plugins/fish/stt.py index f6e300e0..fcfff0d4 100644 --- a/plugins/fish/vision_agents/plugins/fish/stt.py +++ b/plugins/fish/vision_agents/plugins/fish/stt.py @@ -2,16 +2,16 @@ import logging import os import wave -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import Optional import numpy as np from fish_audio_sdk import Session, ASRRequest from getstream.video.rtc.track_util import PcmData from vision_agents.core import stt +from vision_agents.core.stt import TranscriptResponse -if TYPE_CHECKING: - from vision_agents.core.edge.types import Participant +from vision_agents.core.edge.types import Participant logger = logging.getLogger(__name__) @@ -37,39 +37,19 @@ def __init__( self, api_key: Optional[str] = None, language: Optional[str] = None, - ignore_timestamps: bool = False, - sample_rate: int = 16000, - base_url: Optional[str] = None, client: Optional[Session] = None, ): - """ - Initialize the Fish Audio STT service. - - Args: - api_key: Fish Audio API key. If not provided, the FISH_API_KEY - environment variable will be used. - language: Language code for transcription (e.g., "en", "zh"). If None, - automatic language detection will be used. - ignore_timestamps: Skip timestamp processing for faster results. - sample_rate: Sample rate of the audio in Hz (default: 16000). - base_url: Optional custom API endpoint. - client: Optionally pass in your own instance of the Fish Audio Session. 
- """ - super().__init__(sample_rate=sample_rate, provider_name="fish") + super().__init__(provider_name="fish") if not api_key: api_key = os.environ.get("FISH_API_KEY") if client is not None: self.client = client - elif base_url: - self.client = Session(api_key, base_url=base_url) else: self.client = Session(api_key) self.language = language - self.ignore_timestamps = ignore_timestamps - self._current_user: Optional[Union[Dict[str, Any], "Participant"]] = None @staticmethod def _pcm_to_wav_bytes(pcm_data: PcmData) -> bytes: @@ -88,11 +68,11 @@ def _pcm_to_wav_bytes(pcm_data: PcmData) -> bytes: return wav_buffer.getvalue() - async def _process_audio_impl( + async def process_audio( self, pcm_data: PcmData, - user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None, - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: + participant: Optional[Participant] = None, + ): """ Process audio data through Fish Audio for transcription. @@ -107,14 +87,10 @@ async def _process_audio_impl( List of tuples (is_final, text, metadata) representing transcription results, or None if no results are available. Fish Audio returns final results only. """ - if self._is_closed: + if self.closed: logger.warning("Fish Audio STT is closed, ignoring audio") return None - - # Store the current user context - self._current_user = user_metadata - # Check if we have valid audio data if not hasattr(pcm_data, "samples") or pcm_data.samples is None: logger.warning("No audio samples to process") @@ -133,7 +109,7 @@ async def _process_audio_impl( asr_request = ASRRequest( audio=wav_data, language=self.language, - ignore_timestamps=self.ignore_timestamps, + ignore_timestamps=True, ) # Send to Fish Audio API @@ -150,23 +126,12 @@ async def _process_audio_impl( logger.error("No transcript returned from Fish Audio %s", pcm_data.duration) return None - # Build metadata from response - metadata: Dict[str, Any] = { - "audio_duration_ms": response.duration, - "language": self.language or "auto", - "model_name": "fish-audio-asr", - } - - # Include segments if timestamps were requested - if not self.ignore_timestamps and response.segments: - metadata["segments"] = [ - { - "text": segment.text, - "start": segment.start, - "end": segment.end, - } - for segment in response.segments - ] + # Build response metadata + response_metadata = TranscriptResponse( + audio_duration_ms=response.duration, + language=self.language or "auto", + model_name="fish-audio-asr", + ) logger.debug( "Received transcript from Fish Audio", @@ -176,8 +141,7 @@ async def _process_audio_impl( }, ) - # Return as final result (Fish Audio doesn't support streaming/partial results) - return [(True, transcript_text, metadata)] + self._emit_transcript_event(transcript_text, participant, response_metadata) except Exception as e: logger.error( @@ -187,12 +151,3 @@ async def _process_audio_impl( # Let the base class handle error emission raise - async def close(self): - """Close the Fish Audio STT service and clean up resources.""" - if self._is_closed: - logger.debug("Fish Audio STT service already closed") - return - - logger.info("Closing Fish Audio STT service") - await super().close() - diff --git a/plugins/moonshine/tests/test_moonshine_stt.py b/plugins/moonshine/tests/test_moonshine_stt.py new file mode 100644 index 00000000..280dee96 --- /dev/null +++ b/plugins/moonshine/tests/test_moonshine_stt.py @@ -0,0 +1,50 @@ +import pytest + +from vision_agents.plugins import moonshine +from conftest import STTSession + + +class TestMoonshineSTT: + """Integration 
tests for Moonshine STT""" + + @pytest.fixture + async def stt(self): + """Create and manage Moonshine STT lifecycle""" + stt = moonshine.STT() + try: + yield stt + finally: + await stt.close() + + @pytest.mark.integration + async def test_transcribe_mia_audio(self, stt, mia_audio_16khz): + # Create session to collect transcripts and errors + session = STTSession(stt) + + # Process the audio + await stt.process_audio(mia_audio_16khz) + + # Wait for result + await session.wait_for_result(timeout=30.0) + assert not session.errors + + # Verify transcript + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() + + @pytest.mark.integration + async def test_transcribe_mia_audio_48khz(self, stt, mia_audio_48khz): + # Create session to collect transcripts and errors + session = STTSession(stt) + + # Process the audio + await stt.process_audio(mia_audio_48khz) + + # Wait for result + await session.wait_for_result(timeout=30.0) + assert not session.errors + + # Verify transcript + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() + diff --git a/plugins/moonshine/tests/test_stt.py b/plugins/moonshine/tests/test_stt.py deleted file mode 100644 index 711b5c54..00000000 --- a/plugins/moonshine/tests/test_stt.py +++ /dev/null @@ -1,774 +0,0 @@ -import pytest -import asyncio -import numpy as np -from unittest.mock import patch - -from vision_agents.plugins import moonshine -from getstream.video.rtc.track_util import PcmData -from plugins.plugin_test_utils import get_audio_asset, get_json_metadata - -# Skip all tests in this module if moonshine_onnx is not installed -try: - import moonshine_onnx # noqa: F401 -except ImportError: - pytest.skip( - "moonshine_onnx is not installed. 
Skipping all Moonshine STT tests.", - allow_module_level=True, - ) - - -# Mock moonshine module for tests that don't require the actual library -class MockMoonshine: - @staticmethod - def transcribe(audio_path, model_name): - """Mock transcribe function that returns a simple result.""" - # Simulate different responses based on model name - if "base" in model_name: - return ["This is a high quality transcription from the base model."] - else: - return ["This is a transcription from the tiny model."] - - -@pytest.fixture -def mia_mp3_path(): - """Return the path to the mia.mp3 test file.""" - return get_audio_asset("mia.mp3") - - -@pytest.fixture -def mia_json_path(): - """Return the path to the mia.json metadata file.""" - return get_audio_asset("mia.json") - - -@pytest.fixture -def mia_metadata(): - """Load the mia.json metadata.""" - return get_json_metadata("mia.json") - - -@pytest.fixture -def mia_audio_data(mia_mp3_path): - """Load and prepare the mia.mp3 audio data for testing.""" - try: - # Try to load the mp3 file using soundfile - import soundfile as sf - - data, original_sample_rate = sf.read(mia_mp3_path) - - # Convert to mono if stereo - if len(data.shape) > 1: - data = np.mean(data, axis=1) - - # Resample to 16kHz (Moonshine's native rate) - target_sample_rate = 16000 - if original_sample_rate != target_sample_rate: - from getstream.audio.utils import resample_audio - - data = resample_audio(data, original_sample_rate, target_sample_rate) - - # Normalize and convert to int16 - if data.max() > 1.0 or data.min() < -1.0: - data = data / max(abs(data.max()), abs(data.min())) - - # Convert to int16 PCM - pcm_samples = (data * 32767.0).astype(np.int16) - - # Return PCM data with the resampled rate - return PcmData( - samples=pcm_samples, sample_rate=target_sample_rate, format="s16" - ) - except Exception: - # Fall back to synthetic data if file loading fails - sample_rate = 16000 - duration_sec = 2 - t = np.linspace(0, duration_sec, int(duration_sec * sample_rate)) - - # Create speech-like signal with multiple formants - audio_data = np.zeros_like(t) - for formant, amplitude in [(600, 1.0), (1200, 0.5), (2400, 0.2)]: - audio_data += amplitude * np.sin(2 * np.pi * formant * t) - - # Normalize and convert to int16 - audio_data = audio_data / np.max(np.abs(audio_data)) - pcm_samples = (audio_data * 32767.0).astype(np.int16) - - return PcmData(samples=pcm_samples, sample_rate=sample_rate, format="s16") - - -@pytest.fixture -def audio_data_16k(): - """Load and prepare 16kHz audio data for testing.""" - try: - # Try to load real audio asset - audio_path = get_audio_asset("formant_speech_16k.wav") - import soundfile as sf - - data, sample_rate = sf.read(audio_path) - - # Convert to mono if stereo - if len(data.shape) > 1: - data = np.mean(data, axis=1) - - # Convert to int16 - pcm_samples = (data * 32767.0).astype(np.int16) - - return PcmData(samples=pcm_samples, sample_rate=sample_rate, format="s16") - except Exception: - # Fall back to synthetic data - sample_rate = 16000 - duration_sec = 2 - t = np.linspace(0, duration_sec, int(duration_sec * sample_rate)) - - # Create speech-like signal with multiple formants - audio_data = np.zeros_like(t) - for formant, amplitude in [(600, 1.0), (1200, 0.5), (2400, 0.2)]: - audio_data += amplitude * np.sin(2 * np.pi * formant * t) - - # Normalize and convert to int16 - audio_data = audio_data / np.max(np.abs(audio_data)) - pcm_samples = (audio_data * 32767.0).astype(np.int16) - - return PcmData(samples=pcm_samples, sample_rate=sample_rate, 
format="s16") - - -@pytest.fixture -def audio_data_48k(): - """Load and prepare 48kHz audio data for testing.""" - try: - # Try to load real audio asset - audio_path = get_audio_asset("formant_speech_48k.wav") - import soundfile as sf - - data, sample_rate = sf.read(audio_path) - - # Convert to mono if stereo - if len(data.shape) > 1: - data = np.mean(data, axis=1) - - # Convert to int16 - pcm_samples = (data * 32767.0).astype(np.int16) - - return PcmData(samples=pcm_samples, sample_rate=sample_rate, format="s16") - except Exception: - # Fall back to synthetic data - sample_rate = 48000 - duration_sec = 2 - t = np.linspace(0, duration_sec, int(duration_sec * sample_rate)) - - # Create speech-like signal - audio_data = np.zeros_like(t) - for formant, amplitude in [(600, 1.0), (1200, 0.5), (2400, 0.2)]: - audio_data += amplitude * np.sin(2 * np.pi * formant * t) - - # Normalize and convert to int16 - audio_data = audio_data / np.max(np.abs(audio_data)) - pcm_samples = (audio_data * 32767.0).astype(np.int16) - - return PcmData(samples=pcm_samples, sample_rate=sample_rate, format="s16") - - -@pytest.mark.asyncio -async def test_moonshine_model_validation(): - """Test that Moonshine validates model names correctly.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - # Test invalid model name - with pytest.raises(ValueError, match="Unknown Moonshine model"): - moonshine.STT(model_name="invalid_model") - - # Test valid model names - stt1 = moonshine.STT(model_name="tiny") - assert stt1.model_name == "moonshine/tiny" - await stt1.close() - - stt2 = moonshine.STT(model_name="moonshine/base") - assert stt2.model_name == "moonshine/base" - await stt2.close() - - -@pytest.mark.asyncio -async def test_moonshine_initialization(): - """Test that Moonshine initializes correctly with mocked library.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT() - assert stt is not None - assert stt.model_name == "moonshine/base" # Canonical value after validation - assert stt.sample_rate == 16000 - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_custom_initialization(): - """Test Moonshine initialization with custom parameters.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT( - model_name="moonshine/base", - sample_rate=16000, - min_audio_length_ms=1000, - ) - - assert stt.model_name == "moonshine/base" # Canonical value after validation - assert stt.sample_rate == 16000 - assert stt.min_audio_length_ms == 1000 - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_audio_resampling(): - """Test that audio resampling works correctly.""" - from getstream.audio.utils import resample_audio - - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT(sample_rate=16000) - - # Test resampling from 48kHz to 16kHz using the shared utility - original_data = np.random.randint( - -1000, 1000, 48000, dtype=np.int16 - ) # 1 second at 48kHz - resampled = resample_audio(original_data, 48000, 16000).astype(np.int16) - - # Should be approximately 16000 samples (1 second at 16kHz) - assert abs(len(resampled) - 16000) < 100 # Allow some tolerance - assert resampled.dtype == np.int16 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_audio_normalization(): - """Test that audio normalization works correctly.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT() - - # Test normalization - int16_data = np.array([32767, -32768, 
0, 16384], dtype=np.int16) - normalized = stt._normalize_audio(int16_data) - - assert normalized.dtype == np.float32 - assert np.allclose(normalized, [1.0, -1.0, 0.0, 0.5], atol=1e-4) - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_immediate_processing(): - """Test that audio is processed immediately without buffering.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT(sample_rate=16000) - - # Mock the _transcribe_audio method to track calls - transcribe_calls = [] - - async def mock_transcribe_audio(audio_data): - transcribe_calls.append(len(audio_data)) - return "test transcription" - - stt._transcribe_audio = mock_transcribe_audio - - # Create test audio - audio_array = np.random.randint( - -1000, 1000, 8000, dtype=np.int16 - ) # 0.5 seconds - pcm_data = PcmData(samples=audio_array, sample_rate=16000, format="s16") - - # Process audio - should be immediate, no buffering - await stt.process_audio(pcm_data) - - # Should have called transcribe immediately - assert len(transcribe_calls) == 1 - assert transcribe_calls[0] == 8000 # Same length as input - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_process_audio_short_chunk(audio_data_16k): - """Test processing audio that's too short to trigger transcription.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT(min_audio_length_ms=1000) # Require 1s minimum - - # Track events - transcripts = [] - errors = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - @stt.events.subscribe - def on_error(error): - errors.append(error) - - # Create short audio (0.5 seconds) - short_audio = PcmData( - samples=audio_data_16k.samples[:8000], # 0.5 seconds at 16kHz - sample_rate=16000, - format="s16", - ) - - # Process the short audio - await stt.process_audio(short_audio) - - # Should not trigger transcription due to min_audio_length_ms - assert len(transcripts) == 0 - assert len(errors) == 0 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_process_audio_sufficient_chunk(audio_data_16k): - """Test processing audio that's long enough to trigger transcription.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - # Mock the _transcribe_audio method directly instead of moonshine.transcribe - stt = moonshine.STT(min_audio_length_ms=500) - - # Track events - transcripts = [] - errors = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - @stt.events.subscribe - def on_error(error): - errors.append(error) - - # Mock the _transcribe_audio method to return a test result - async def mock_transcribe_audio(audio_data): - return "This is a test transcription" - - stt._transcribe_audio = mock_transcribe_audio - - # Process sufficient audio - await stt.process_audio(audio_data_16k) - - # Give some time for async processing - await asyncio.sleep(0.1) - - # Should trigger transcription - assert len(transcripts) > 0 - assert len(errors) == 0 - - # Check transcript content - text, user, metadata = transcripts[0] - assert isinstance(text, str) - assert len(text) > 0 - assert "model_name" in metadata - assert "audio_duration_ms" in metadata - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_process_audio_with_resampling(audio_data_48k): - """Test processing audio that requires resampling.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt 
= moonshine.STT(sample_rate=16000) - - # Track events - transcripts = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - # Mock the _transcribe_audio method to return a test result - async def mock_transcribe_audio(audio_data): - return "This is a test transcription" - - stt._transcribe_audio = mock_transcribe_audio - - # Process 48kHz audio (should be resampled to 16kHz) - await stt.process_audio(audio_data_48k) - - # Give some time for async processing - await asyncio.sleep(0.1) - - # Should trigger transcription after resampling - assert len(transcripts) > 0 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_flush_functionality(audio_data_16k): - """Test that flush is a no-op since we no longer buffer.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT(min_audio_length_ms=500) - - # Track events - transcripts = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - # Mock the _transcribe_audio method to return a test result - async def mock_transcribe_audio(audio_data): - return "This is a test transcription" - - stt._transcribe_audio = mock_transcribe_audio - - # Process audio - should trigger immediate transcription - audio = PcmData( - samples=audio_data_16k.samples[:16000], # 1 second - sample_rate=16000, - format="s16", - ) - await stt.process_audio(audio) - - # Give some time for async processing - await asyncio.sleep(0.1) - - # Should have triggered transcription immediately - assert len(transcripts) == 1 - - # Flush should be a no-op - await stt.flush() - - # Should still have only one transcript - assert len(transcripts) == 1 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_bytes_input(): - """Test processing audio data provided as bytes.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT() - - # Track events - transcripts = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - # Mock the _transcribe_audio method to return a test result - async def mock_transcribe_audio(audio_data): - return "This is a test transcription" - - stt._transcribe_audio = mock_transcribe_audio - - # Create audio as bytes - audio_array = np.random.randint(-1000, 1000, 16000, dtype=np.int16) # 1 second - audio_bytes = audio_array.tobytes() - - pcm_data = PcmData(samples=audio_bytes, sample_rate=16000, format="s16") - await stt.process_audio(pcm_data) - - # Give some time for async processing - await asyncio.sleep(0.1) - - # Should trigger transcription - assert len(transcripts) > 0 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_error_handling(): - """Test error handling during transcription.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT() - - # Track events - errors = [] - - @stt.events.subscribe - def on_error(error): - errors.append(error) - - # Mock the _transcribe_audio method to raise an exception - async def mock_transcribe_audio(audio_data): - raise Exception("Transcription failed") - - stt._transcribe_audio = mock_transcribe_audio - - # Create sufficient audio - audio_array = np.random.randint(-1000, 1000, 16000, dtype=np.int16) - pcm_data = PcmData(samples=audio_array, sample_rate=16000, format="s16") - - await stt.process_audio(pcm_data) - - # Give some time for async processing - await asyncio.sleep(0.1) 
- - # Should have captured the error - assert len(errors) > 0 - - await stt.close() - - -@pytest.mark.asyncio -async def test_moonshine_closed_state(): - """Test that processing is ignored when STT is closed.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - stt = moonshine.STT() - - # Close the STT - await stt.close() - - # Try to process audio - audio_array = np.random.randint(-1000, 1000, 16000, dtype=np.int16) - pcm_data = PcmData(samples=audio_array, sample_rate=16000, format="s16") - - result = await stt._process_audio_impl(pcm_data) - - # Should return None when closed - assert result is None - - -@pytest.mark.asyncio -async def test_moonshine_model_selection(): - """Test that different models produce different results.""" - with patch("vision_agents.plugins.moonshine.stt.moonshine"): - # Test tiny model - stt_tiny = moonshine.STT(model_name="moonshine/tiny") - - # Test base model - stt_base = moonshine.STT(model_name="moonshine/base") - - # Both should initialize successfully with canonical names - assert ( - stt_tiny.model_name == "moonshine/tiny" - ) # Canonical value after validation - assert ( - stt_base.model_name == "moonshine/base" - ) # Canonical value after validation - - await stt_tiny.close() - await stt_base.close() - - -@pytest.mark.asyncio -async def test_moonshine_with_mia_audio_mocked(mia_audio_data, mia_metadata): - """Test Moonshine STT with mia.mp3 audio using mocked transcription.""" - with patch("vision_agents.plugins.moonshine.stt.stt.moonshine") as mock_moonshine: - # Extract expected text from mia.json metadata - expected_segments = mia_metadata.get("segments", []) - expected_full_text = " ".join( - [segment["text"] for segment in expected_segments] - ).strip() - - # Mock the transcribe function to return the expected text - mock_moonshine.transcribe.return_value = [expected_full_text] - - stt = moonshine.STT(model_name="moonshine/base", min_audio_length_ms=500) - - # Track events - transcripts = [] - errors = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - @stt.events.subscribe - def on_error(error): - errors.append(error) - - # Process the mia audio data - await stt.process_audio(mia_audio_data) - - # Wait for processing - await asyncio.sleep(0.1) - - # Flush any remaining audio - await stt.flush() - await asyncio.sleep(0.1) - - # Verify results - assert len(errors) == 0, f"Received errors: {errors}" - assert len(transcripts) > 0, "No transcripts received" - - # Check transcript content - text, user, metadata = transcripts[0] - assert isinstance(text, str) - assert len(text) > 0 - assert "model_name" in metadata - assert "audio_duration_ms" in metadata - - # Verify the transcript contains expected content - assert "mia" in text.lower() - assert "village" in text.lower() or "treasure" in text.lower() - - # Verify metadata structure - assert metadata["model_name"] == "moonshine/base" - assert ( - metadata["confidence"] is None - ) # Moonshine doesn't provide confidence scores - assert metadata["target_sample_rate"] == 16000 - assert "processing_time_ms" in metadata - assert "original_sample_rate" in metadata - assert "resampled" in metadata - - # Verify the mock was called correctly - mock_moonshine.transcribe.assert_called() - call_args = mock_moonshine.transcribe.call_args - assert len(call_args[0]) == 2 # audio_path, model_name - assert call_args[0][1] == "moonshine/base" # model_name - - await stt.close() - - -# Integration test with real Moonshine (if available) 
-@pytest.mark.integration -@pytest.mark.asyncio -async def test_moonshine_real_integration(mia_audio_data, mia_metadata): - """ - Integration test with the real Moonshine library using the mia.mp3 test file. - - This test processes the mia.mp3 audio file and compares the transcription results - with the expected content from mia.json metadata. - - This test will be skipped if Moonshine is not installed. - """ - # Only run if we have a reasonable amount of audio - if len(mia_audio_data.samples) < 8000: # Less than 0.5 seconds - pytest.skip("Audio sample too short for meaningful integration test") - - print( - f"Testing with mia.mp3: {len(mia_audio_data.samples)} samples at {mia_audio_data.sample_rate}Hz" - ) - print( - f"Audio duration: {len(mia_audio_data.samples) / mia_audio_data.sample_rate:.2f} seconds" - ) - print( - f"Audio range: {mia_audio_data.samples.min()} to {mia_audio_data.samples.max()}" - ) - - # Extract expected text from mia.json metadata - expected_segments = mia_metadata.get("segments", []) - expected_full_text = " ".join( - [segment["text"] for segment in expected_segments] - ).strip() - expected_words = expected_full_text.lower().split() - - print(f"Expected transcript: {expected_full_text}") - print(f"Expected word count: {len(expected_words)}") - - stt = moonshine.STT( - model_name="moonshine/tiny", # Use tiny model for faster testing - min_audio_length_ms=500, - ) - - # Track events - transcripts = [] - errors = [] - - @stt.events.subscribe - def on_transcript(text, user, metadata): - transcripts.append((text, user, metadata)) - - @stt.events.subscribe - def on_error(error): - errors.append(error) - - try: - # Process the audio in chunks to simulate real-time processing - chunk_size = 8000 # Process in 0.5 second chunks at 16kHz - total_samples = len(mia_audio_data.samples) - - for i in range(0, total_samples, chunk_size): - end_idx = min(i + chunk_size, total_samples) - chunk_samples = mia_audio_data.samples[i:end_idx] - - chunk_data = PcmData( - samples=chunk_samples, - sample_rate=mia_audio_data.sample_rate, - format=mia_audio_data.format, - ) - - await stt.process_audio(chunk_data) - await asyncio.sleep(0.1) # Small delay between chunks - - # Wait for processing to complete - await asyncio.sleep(2.0) - - # Flush any remaining audio - await stt.flush() - await asyncio.sleep(1.0) - - # Check results - print(f"Transcripts received: {len(transcripts)}") - print(f"Errors received: {len(errors)}") - - if transcripts: - for i, (text, user, metadata) in enumerate(transcripts): - print(f"Transcript {i + 1}: {text}") - print(f"Metadata: {metadata}") - - if errors: - for i, error in enumerate(errors): - print(f"Error {i + 1}: {error}") - - # We should either get transcripts or errors, but not silence - assert len(transcripts) > 0 or len(errors) > 0, ( - "No transcripts or errors received" - ) - - # If we got transcripts, verify they contain reasonable content - if transcripts: - # Combine all transcript text - combined_text = " ".join([t[0] for t in transcripts]).strip() - actual_words = combined_text.lower().split() - - print(f"Combined transcript: {combined_text}") - print(f"Actual word count: {len(actual_words)}") - - # Basic validation - text, user, metadata = transcripts[0] - assert isinstance(text, str) - assert len(text.strip()) > 0 - assert "model_name" in metadata - assert metadata["model_name"] == "moonshine/tiny" - assert "audio_duration_ms" in metadata - assert metadata["audio_duration_ms"] > 0 - - # Content validation - check for key words from the expected 
transcript - # We'll be lenient since STT accuracy can vary - key_words = [ - "mia", - "village", - "brushes", - "map", - "treasure", - "fields", - "hues", - "discovered", - ] - found_key_words = [ - word for word in key_words if word in combined_text.lower() - ] - - print(f"Key words found: {found_key_words}") - - # We should find at least some key words from the story - assert len(found_key_words) >= 2, ( - f"Expected to find at least 2 key words from {key_words}, but only found {found_key_words}" - ) - - # Check that we got a reasonable amount of text - assert len(actual_words) >= 10, ( - f"Expected at least 10 words, but got {len(actual_words)}: {combined_text}" - ) - - # Verify metadata structure - assert "processing_time_ms" in metadata - assert "confidence" in metadata - assert ( - metadata["confidence"] is None - ) # Moonshine doesn't provide confidence scores - assert "original_sample_rate" in metadata - assert "target_sample_rate" in metadata - assert metadata["target_sample_rate"] == 16000 # Moonshine's native rate - - # We shouldn't have any errors - assert len(errors) == 0, f"Received errors: {errors}" - - finally: - await stt.close() diff --git a/plugins/moonshine/vision_agents/plugins/moonshine/stt.py b/plugins/moonshine/vision_agents/plugins/moonshine/stt.py index afe0b25c..4b526e44 100644 --- a/plugins/moonshine/vision_agents/plugins/moonshine/stt.py +++ b/plugins/moonshine/vision_agents/plugins/moonshine/stt.py @@ -1,21 +1,23 @@ -import os import logging +import os import time -from typing import Dict, Any, Optional, Tuple, List, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple -if TYPE_CHECKING: - from vision_agents.core.edge.types import Participant import numpy as np import soundfile as sf - -from vision_agents.core import stt -from getstream.video.rtc.track_util import PcmData -from getstream.audio.utils import resample_audio from getstream.audio.pcm_utils import ( + log_audio_processing_info, pcm_to_numpy_array, validate_sample_rate_compatibility, - log_audio_processing_info, ) +from getstream.audio.utils import resample_audio +from getstream.video.rtc.track_util import PcmData + +from vision_agents.core import stt +from vision_agents.core.stt import TranscriptResponse + +if TYPE_CHECKING: + from vision_agents.core.edge.types import Participant logger = logging.getLogger(__name__) @@ -59,7 +61,6 @@ class STT(stt.STT): def __init__( self, model_name: str = "moonshine/base", - sample_rate: int = 16000, min_audio_length_ms: int = 100, target_dbfs: float = -26.0, ): @@ -68,11 +69,13 @@ def __init__( Args: model_name: Moonshine model to use ("moonshine/tiny" or "moonshine/base") - sample_rate: Sample rate of the audio in Hz (default: 16000, Moonshine's native rate) min_audio_length_ms: Minimum audio length required for transcription target_dbfs: Target RMS level in dBFS for audio normalization (default: -26.0, Moonshine's optimal level) """ - super().__init__(sample_rate=sample_rate) + super().__init__(provider_name="moonshine") + + # Moonshine's native sample rate + self.sample_rate = 16000 # Check if moonshine_onnx is available if not MOONSHINE_AVAILABLE: @@ -91,17 +94,12 @@ def __init__( self.min_audio_length_ms = min_audio_length_ms self.target_dbfs = target_dbfs - # Track current user context - self._current_user: Optional[Dict[str, Any]] = None - # Local explicit state flags for mypy visibility - self._is_closed: bool = False - logger.info( "Initialized Moonshine STT", extra={ "model_name": model_name, "canonical_model": self.model_name, - 
"sample_rate": sample_rate, + "sample_rate": self.sample_rate, "min_audio_length_ms": min_audio_length_ms, "target_dbfs": target_dbfs, }, @@ -262,30 +260,21 @@ async def _transcribe_audio( logger.error("Error during transcription", exc_info=e) return None - async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional["Participant"] = None, + ): """ Process audio data through Moonshine for transcription. - Moonshine operates in synchronous mode - it processes audio immediately and - returns results to the base class for event emission. - Args: - pcm_data: The PCM audio data to process. - user_metadata: Additional metadata about the user or session. - - Returns: - List of tuples (is_final, text, metadata) representing transcription results, - or None if no results are available. Moonshine returns final results only - since it doesn't support streaming transcription. + pcm_data: The PCM audio data to process + participant: Optional participant metadata """ - if self._is_closed: + if self.closed: logger.warning("Moonshine STT is closed, ignoring audio") - return None - - # Store the current user context - self._current_user = user_metadata # type: ignore[assignment] + return try: # Log incoming audio details for debugging using shared utility @@ -318,44 +307,25 @@ async def _process_audio_impl( text = result processing_time_ms = 0.0 # Default for mocked tests - # Create metadata - metadata = { - "model_name": self.model_name, - "audio_duration_ms": (len(audio_array) / self.sample_rate) * 1000, - "processing_time_ms": processing_time_ms, - # Moonshine doesn't provide confidence scores, so we don't set it - "original_sample_rate": pcm_data.sample_rate, - "target_sample_rate": self.sample_rate, - "resampled": pcm_data.sample_rate != self.sample_rate, - } - - # Return as final transcript (Moonshine doesn't support streaming) - return [(True, text, metadata)] + # Create response metadata + response_metadata = TranscriptResponse( + model_name=self.model_name, + audio_duration_ms=(len(audio_array) / self.sample_rate) * 1000, + processing_time_ms=processing_time_ms, + ) - return None + # Emit transcript event + self._emit_transcript_event(text, participant, response_metadata) except Exception as e: # Use the base class helper for consistent error handling self._emit_error_event(e, "Moonshine audio processing") - return None - - async def flush(self): - """ - Flush any remaining audio in the buffer and process it. - - Note: This is a no-op since we no longer buffer audio. 
- """ - # No buffering, so nothing to flush - pass async def close(self): """Close the Moonshine STT service and clean up resources.""" - if self._is_closed: + if self.closed: logger.debug("Moonshine STT service already closed") return logger.info("Closing Moonshine STT service") - self._is_closed = True - - # No buffers to clear since we don't buffer anymore - logger.debug("Moonshine STT service closed successfully") + await super().close() diff --git a/plugins/ultralytics/vision_agents/plugins/ultralytics/yolo_pose_processor.py b/plugins/ultralytics/vision_agents/plugins/ultralytics/yolo_pose_processor.py index 761695ce..0adbbd06 100644 --- a/plugins/ultralytics/vision_agents/plugins/ultralytics/yolo_pose_processor.py +++ b/plugins/ultralytics/vision_agents/plugins/ultralytics/yolo_pose_processor.py @@ -96,15 +96,10 @@ async def recv(self) -> av.frame.Frame: pts, time_base = await self.next_timestamp() # Create av.VideoFrame from PIL Image - try: - av_frame = self.last_frame - - av_frame.pts = pts - av_frame.time_base = time_base - except Exception: - import pdb + av_frame = self.last_frame - pdb.set_trace() + av_frame.pts = pts + av_frame.time_base = time_base # if frame_received: # logger.info(f"Returning NEW video frame: {av_frame.width}x{av_frame.height}") diff --git a/plugins/wizper/tests/test_stt.py b/plugins/wizper/tests/test_stt.py deleted file mode 100644 index f8f0e871..00000000 --- a/plugins/wizper/tests/test_stt.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Tests for the fal.STT plugin.""" - -import asyncio -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -import numpy as np - -from vision_agents.plugins import wizper -from vision_agents.core.stt.events import STTTranscriptEvent, STTErrorEvent -from getstream.video.rtc.track_util import PcmData - - -@pytest.fixture -async def stt(): - """Provides a fal.STT instance with a mocked fal_client.""" - with patch("fal_client.AsyncClient") as mock_fal_client: - stt_instance = wizper.STT() - stt_instance._fal_client = mock_fal_client.return_value - yield stt_instance - - -class TestfalSTT: - """Test suite for the fal.STT class.""" - - @pytest.mark.asyncio - async def test_init(self): - """Test that the __init__ method sets attributes correctly.""" - stt = wizper.STT(task="translate", target_language="es", sample_rate=16000) - assert stt.task == "translate" - assert stt.target_language == "es" - assert stt.sample_rate == 16000 - assert not stt._is_closed - - def test_pcm_to_wav_bytes(self, stt): - """Test the conversion of PCM data to WAV bytes.""" - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 63).astype(np.int16) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - wav_bytes = stt._pcm_to_wav_bytes(pcm_data) - assert wav_bytes.startswith(b"RIFF") - assert b"WAVE" in wav_bytes - - @pytest.mark.asyncio - async def test_process_audio_impl_success_transcribe(self, stt): - """Test successful transcription with a valid response.""" - stt._fal_client.upload_file = AsyncMock( - return_value="http://mock.url/audio.wav" - ) - stt._fal_client.subscribe = AsyncMock( - return_value={"text": " This is a test. 
", "chunks": []} - ) - - transcript_handler = AsyncMock() - - @stt.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - await transcript_handler(event) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 8191).astype(np.int16) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - - with ( - patch( - "tempfile.NamedTemporaryFile", new_callable=MagicMock - ) as mock_temp_file, - patch("os.unlink", new_callable=MagicMock) as mock_unlink, - ): - await stt._process_audio_impl(pcm_data, {"user": "test_user"}) - # Allow event loop to process the event emission - await asyncio.sleep(0.01) - - mock_temp_file.assert_called_once_with(suffix=".wav", delete=False) - stt._fal_client.upload_file.assert_awaited_once() - stt._fal_client.subscribe.assert_awaited_once() - - subscribe_args = stt._fal_client.subscribe.call_args.kwargs - assert ( - subscribe_args["arguments"]["audio_url"] == "http://mock.url/audio.wav" - ) - assert subscribe_args["arguments"]["task"] == "transcribe" - assert "language" not in subscribe_args["arguments"] - - # Check that the transcript handler was called with an event object - transcript_handler.assert_called_once() - call_args = transcript_handler.call_args[0] - assert len(call_args) == 1 - event = call_args[0] - assert event.text == "This is a test." - assert event.user_metadata == {"user": "test_user"} - mock_unlink.assert_called_once() - - @pytest.mark.asyncio - async def test_process_audio_impl_success_translate(self): - """Test successful translation with target language.""" - with patch("fal_client.AsyncClient"): - stt = wizper.STT(task="translate", target_language="pt") - stt._fal_client.upload_file = AsyncMock( - return_value="http://mock.url/audio.wav" - ) - stt._fal_client.subscribe = AsyncMock( - return_value={"text": "This is a test.", "chunks": []} - ) - - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 32767).astype( - np.int16 - ) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - with patch("tempfile.NamedTemporaryFile"), patch("os.unlink"): - await stt._process_audio_impl(pcm_data) - - subscribe_args = stt._fal_client.subscribe.call_args.kwargs - assert subscribe_args["arguments"]["language"] == "pt" - - @pytest.mark.asyncio - async def test_process_audio_impl_no_text(self, stt): - """Test that no transcript is emitted if the API response lacks 'text'.""" - stt._fal_client.upload_file = AsyncMock( - return_value="http://mock.url/audio.wav" - ) - stt._fal_client.subscribe = AsyncMock( - return_value={"chunks": []} - ) # No 'text' field - - transcript_handler = AsyncMock() - - @stt.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - await transcript_handler(event) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 32767).astype( - np.int16 - ) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - - with patch("tempfile.NamedTemporaryFile"), patch("os.unlink"): - await stt._process_audio_impl(pcm_data) - - transcript_handler.assert_not_called() - - @pytest.mark.asyncio - async def test_process_audio_impl_empty_text(self, stt): - """Test that no transcript is emitted for empty or whitespace-only text.""" - stt._fal_client.upload_file = AsyncMock( - return_value="http://mock.url/audio.wav" - ) - stt._fal_client.subscribe = AsyncMock( - return_value={"text": 
" ", "chunks": []} - ) # Empty text - - transcript_handler = AsyncMock() - - @stt.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - await transcript_handler(event) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 32767).astype( - np.int16 - ) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - - with patch("tempfile.NamedTemporaryFile"), patch("os.unlink"): - await stt._process_audio_impl(pcm_data) - - transcript_handler.assert_not_called() - - @pytest.mark.asyncio - async def test_process_audio_impl_api_error(self, stt): - """Test that an error event is emitted when the API call fails.""" - stt._fal_client.upload_file = AsyncMock(side_effect=Exception("API Error")) - - error_handler = AsyncMock() - - @stt.events.subscribe - async def on_error(event: STTErrorEvent): - await error_handler(event) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 32767).astype( - np.int16 - ) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - - with patch("tempfile.NamedTemporaryFile"), patch("os.unlink"): - await stt._process_audio_impl(pcm_data) - # Allow event loop to process the event emission - await asyncio.sleep(0.01) - - # Check that the error handler was called with an event object - error_handler.assert_called_once() - call_args = error_handler.call_args[0] - assert len(call_args) == 1 - event = call_args[0] - assert isinstance(event.error, Exception) - assert str(event.error) == "API Error" - - @pytest.mark.asyncio - async def test_process_audio_impl_empty_audio(self, stt): - """Test that no API call is made for empty audio data.""" - pcm_data = PcmData( - samples=np.array([], dtype=np.int16), - sample_rate=48000, - format="s16", - ) - await stt._process_audio_impl(pcm_data) - stt._fal_client.upload_file.assert_not_called() - - @pytest.mark.asyncio - async def test_process_audio_impl_when_closed(self, stt): - """Test that audio is ignored if the STT service is closed.""" - await stt.close() - samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 480)) * 32767).astype( - np.int16 - ) - pcm_data = PcmData( - samples=samples, - sample_rate=48000, - format="s16", - ) - await stt._process_audio_impl(pcm_data) - stt._fal_client.upload_file.assert_not_called() - - @pytest.mark.asyncio - async def test_close(self, stt): - """Test that the close method works correctly and is idempotent.""" - assert not stt._is_closed - await stt.close() - assert stt._is_closed - # Test idempotency - await stt.close() - assert stt._is_closed diff --git a/plugins/wizper/tests/test_wizper_stt.py b/plugins/wizper/tests/test_wizper_stt.py new file mode 100644 index 00000000..aa618436 --- /dev/null +++ b/plugins/wizper/tests/test_wizper_stt.py @@ -0,0 +1,38 @@ +import pytest +from dotenv import load_dotenv + +from vision_agents.plugins import wizper +from conftest import STTSession + +# Load environment variables +load_dotenv() + + +class TestWizperSTT: + """Integration tests for Wizper STT""" + + @pytest.fixture + async def stt(self): + """Create and manage Wizper STT lifecycle""" + stt = wizper.STT() + try: + yield stt + finally: + await stt.close() + + @pytest.mark.integration + async def test_transcribe_mia_audio_48khz(self, stt, mia_audio_48khz): + # Create session to collect transcripts and errors + session = STTSession(stt) + + # Process the audio + await 
stt.process_audio(mia_audio_48khz) + + # Wait for result + await session.wait_for_result(timeout=30.0) + assert not session.errors + + # Verify transcript + full_transcript = session.get_full_transcript() + assert "forgotten treasures" in full_transcript.lower() + diff --git a/plugins/wizper/vision_agents/plugins/wizper/stt.py b/plugins/wizper/vision_agents/plugins/wizper/stt.py index 60ca6c79..684cef5c 100644 --- a/plugins/wizper/vision_agents/plugins/wizper/stt.py +++ b/plugins/wizper/vision_agents/plugins/wizper/stt.py @@ -24,20 +24,21 @@ async def on_error(error: str): """ import io +import logging import os import tempfile -import time -import logging from pathlib import Path -from typing import Any, Dict, Optional, List, Tuple, Union, TYPE_CHECKING - -if TYPE_CHECKING: - from vision_agents.core.edge.types import Participant +from typing import TYPE_CHECKING, Optional import wave import fal_client from getstream.video.rtc.track_util import PcmData + from vision_agents.core import stt +from vision_agents.core.stt import TranscriptResponse + +if TYPE_CHECKING: + from vision_agents.core.edge.types import Participant logger = logging.getLogger(__name__) @@ -58,23 +59,21 @@ class STT(stt.STT): def __init__( self, task: str = "transcribe", - target_language: str | None = None, - sample_rate: int = 48000, + target_language: Optional[str] = None, client: Optional[fal_client.AsyncClient] = None, ): """ - Initialize FalWizperSTT. + Initialize Wizper STT. Args: task: "transcribe" or "translate" target_language: Target language code (e.g., "pt" for Portuguese) - sample_rate: Sample rate of the audio in Hz. + client: Optional fal_client.AsyncClient instance for testing """ - super().__init__(sample_rate=sample_rate) + super().__init__(provider_name="wizper") self.task = task + self.sample_rate = 48000 self.target_language = target_language - self.last_activity_time = time.time() - self._is_closed = False self._fal_client = client if client is not None else fal_client.AsyncClient() def _pcm_to_wav_bytes(self, pcm_data: PcmData) -> bytes: @@ -98,26 +97,25 @@ def _pcm_to_wav_bytes(self, pcm_data: PcmData) -> bytes: wav_buffer.seek(0) return wav_buffer.read() - async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None - ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional["Participant"] = None, + ): """ - Process accumulated speech audio through fal-ai/wizper. - - This method is typically called by VAD (Voice Activity Detection) systems - when speech segments are detected. + Process audio through fal-ai/wizper for transcription. 
Args: - speech_audio: Accumulated speech audio as numpy array - user: User metadata from the Stream call + pcm_data: The PCM audio data to process + participant: Optional participant metadata """ - if self._is_closed: - logger.debug("connection is closed, ignoring audio") - return None + if self.closed: + logger.warning("Wizper STT is closed, ignoring audio") + return if pcm_data.samples.size == 0: logger.debug("No audio data to process") - return None + return try: logger.debug( @@ -154,8 +152,9 @@ async def _process_audio_impl( if "text" in result: text = result["text"].strip() if text: + response_metadata = TranscriptResponse() self._emit_transcript_event( - text, user_metadata, {"chunks": result.get("chunks", [])} + text, participant, response_metadata ) finally: # Clean up temporary file @@ -164,17 +163,15 @@ async def _process_audio_impl( except OSError: pass - # Return None for asynchronous mode - events are emitted when they arrive - return None - except Exception as e: - logger.error(f"FalWizper processing error: {str(e)}") - self._emit_error_event(e, "FalWizper processing") - return None + logger.error(f"Wizper processing error: {str(e)}") + self._emit_error_event(e, "Wizper processing") async def close(self): - """Close the STT service and release any resources.""" - if self._is_closed: + """Close the Wizper STT service and release any resources.""" + if self.closed: + logger.debug("Wizper STT service already closed") return - self._is_closed = True - logger.info("FalWizperSTT closed") + + logger.info("Closing Wizper STT service") + await super().close() diff --git a/tests/test_stt_base.py b/tests/test_stt_base.py deleted file mode 100644 index 83e79489..00000000 --- a/tests/test_stt_base.py +++ /dev/null @@ -1,355 +0,0 @@ -""" -Test the base STT class consistency improvements. 
-""" - -import asyncio -import os -import pytest -from unittest.mock import Mock - -from dotenv import load_dotenv - -from vision_agents.core.stt.stt import STT -from vision_agents.core.stt.events import STTTranscriptEvent, STTPartialTranscriptEvent, STTErrorEvent -from getstream.video.rtc.track_util import PcmData -from vision_agents.core.agents import Agent -from vision_agents.core.edge.types import User, Participant -from vision_agents.plugins import getstream, openai, deepgram -import numpy as np - -from .base_test import BaseTest - -load_dotenv() - - -class MockSTT(STT): - """Mock STT implementation for testing base class functionality.""" - - def __init__(self): - super().__init__() - self.process_audio_impl_called = False - self.process_audio_impl_result = None - - async def _process_audio_impl(self, pcm_data, user_metadata=None): - self.process_audio_impl_called = True - return self.process_audio_impl_result - - async def close(self): - self._is_closed = True - - -@pytest.fixture -async def mock_stt(): - """Create MockSTT instance in async context.""" - return MockSTT() - - -@pytest.fixture -def valid_pcm_data(): - """Create valid PCM data for testing.""" - samples = np.random.randint(-1000, 1000, size=1000, dtype=np.int16) - return PcmData(samples=samples, sample_rate=16000, format="s16") - - -@pytest.mark.asyncio -async def test_validate_pcm_data_valid(mock_stt, valid_pcm_data): - """Test that valid PCM data passes validation.""" - assert mock_stt._validate_pcm_data(valid_pcm_data) is True - - -@pytest.mark.asyncio -async def test_validate_pcm_data_none(mock_stt): - """Test that None PCM data fails validation.""" - assert mock_stt._validate_pcm_data(None) is False - - -@pytest.mark.asyncio -async def test_validate_pcm_data_no_samples(mock_stt): - """Test that PCM data without samples fails validation.""" - pcm_data = Mock() - pcm_data.samples = None - pcm_data.sample_rate = 16000 - assert mock_stt._validate_pcm_data(pcm_data) is False - - -@pytest.mark.asyncio -async def test_validate_pcm_data_invalid_sample_rate(mock_stt): - """Test that PCM data with invalid sample rate fails validation.""" - pcm_data = Mock() - pcm_data.samples = np.array([1, 2, 3]) - pcm_data.sample_rate = 0 - assert mock_stt._validate_pcm_data(pcm_data) is False - - -@pytest.mark.asyncio -async def test_validate_pcm_data_empty_samples(mock_stt): - """Test that PCM data with empty samples fails validation.""" - pcm_data = Mock() - pcm_data.samples = np.array([]) - pcm_data.sample_rate = 16000 - assert mock_stt._validate_pcm_data(pcm_data) is False - - -@pytest.mark.asyncio -async def test_emit_transcript_event(mock_stt): - """Test that transcript events are emitted correctly.""" - # Set up event listener - transcript_events = [] - - @mock_stt.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - transcript_events.append(event) - - # Emit a transcript event - text = "Hello world" - user_metadata = {"user_id": "123"} - metadata = {"confidence": 0.95, "processing_time_ms": 100} - - mock_stt._emit_transcript_event(text, user_metadata, metadata) - - # Wait for event processing - await mock_stt.events.wait(timeout=1.0) - - # Verify event was emitted - assert len(transcript_events) == 1 - event = transcript_events[0] - assert event.text == text - assert event.user_metadata == user_metadata - assert event.confidence == metadata["confidence"] - assert event.processing_time_ms == metadata["processing_time_ms"] - - -@pytest.mark.asyncio -async def test_emit_partial_transcript_event(mock_stt): - 
"""Test that partial transcript events are emitted correctly.""" - # Set up event listener - partial_events = [] - - @mock_stt.events.subscribe - async def on_partial_transcript(event: STTPartialTranscriptEvent): - partial_events.append(event) - - # Emit a partial transcript event - text = "Hello" - user_metadata = {"user_id": "123"} - metadata = {"confidence": 0.8} - - mock_stt._emit_partial_transcript_event(text, user_metadata, metadata) - - # Wait for event processing - await mock_stt.events.wait(timeout=1.0) - - # Verify event was emitted - assert len(partial_events) == 1 - event = partial_events[0] - assert event.text == text - assert event.user_metadata == user_metadata - assert event.confidence == metadata["confidence"] - - -@pytest.mark.asyncio -async def test_emit_error_event(mock_stt): - """Test that error events are emitted correctly.""" - # Set up event listener - error_events = [] - - @mock_stt.events.subscribe - async def on_error(event: STTErrorEvent): - error_events.append(event) - - # Emit an error event - test_error = Exception("Test error") - mock_stt._emit_error_event(test_error, "test context") - - # Wait for event processing - await mock_stt.events.wait(timeout=1.0) - - # Verify event was emitted - assert len(error_events) == 1 - event = error_events[0] - assert event.error == test_error - assert event.context == "test context" - - -@pytest.mark.asyncio -async def test_process_audio_with_invalid_data(mock_stt): - """Test that process_audio handles invalid data gracefully.""" - # Try to process None data - await mock_stt.process_audio(None) - - # Verify that _process_audio_impl was not called - assert mock_stt.process_audio_impl_called is False - - -@pytest.mark.asyncio -async def test_process_audio_with_valid_data(mock_stt, valid_pcm_data): - """Test that process_audio processes valid data correctly.""" - # Set up mock result - mock_stt.process_audio_impl_result = [(True, "Hello world", {"confidence": 0.95})] - - # Set up event listener - transcript_events = [] - - @mock_stt.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - transcript_events.append(event) - - # Process audio - user_metadata = {"user_id": "123"} - await mock_stt.process_audio(valid_pcm_data, user_metadata) - - # Wait for event processing - await mock_stt.events.wait(timeout=1.0) - - # Verify that _process_audio_impl was called - assert mock_stt.process_audio_impl_called is True - - # Verify that transcript event was emitted - assert len(transcript_events) == 1 - event = transcript_events[0] - assert event.text == "Hello world" - assert event.user_metadata == user_metadata - assert event.confidence == 0.95 - assert event.processing_time_ms is not None # Should be added by base class - - -@pytest.mark.asyncio -async def test_process_audio_when_closed(mock_stt, valid_pcm_data): - """Test that process_audio ignores requests when STT is closed.""" - # Close the STT - await mock_stt.close() - - # Try to process audio - await mock_stt.process_audio(valid_pcm_data) - - # Verify that _process_audio_impl was not called - assert mock_stt.process_audio_impl_called is False - - -@pytest.mark.asyncio -async def test_process_audio_handles_exceptions(mock_stt, valid_pcm_data): - """Test that process_audio handles exceptions from _process_audio_impl.""" - - # Set up mock to raise an exception - class MockSTTWithException(MockSTT): - async def _process_audio_impl(self, pcm_data, user_metadata=None): - raise Exception("Test exception") - - mock_stt_with_exception = MockSTTWithException() - - # Set 
up error event listener - error_events = [] - - @mock_stt_with_exception.events.subscribe - async def on_error(event: STTErrorEvent): - error_events.append(event) - - # Process audio (should not raise exception) - await mock_stt_with_exception.process_audio(valid_pcm_data) - - # Wait for event processing - await mock_stt_with_exception.events.wait(timeout=1.0) - - # Verify that error event was emitted - assert len(error_events) == 1 - event = error_events[0] - assert str(event.error) == "Test exception" - - -# ============================================================================ -# Integration Tests -# ============================================================================ - -class TestSTTIntegration(BaseTest): - """Integration tests for STT with real components.""" - - @pytest.mark.integration - async def test_agent_stt_only_without_tts(self, mia_audio_16khz): - """ - Real integration test: Agent with STT but no TTS. - - Uses real components (Deepgram STT, OpenAI LLM, Stream Edge) - to verify STT-only agents work end-to-end. - - This test verifies: - - Agent can be created with STT but without TTS - - Agent correctly identifies need for audio input - - Agent does not publish audio track (no TTS) - - Audio flows through to STT - - STT transcript events are emitted - - Transcripts are added to conversation - """ - # Skip if required API keys are not present - required_keys = ["DEEPGRAM_API_KEY", "OPENAI_API_KEY", "STREAM_API_KEY"] - missing_keys = [key for key in required_keys if not os.getenv(key)] - if missing_keys: - pytest.skip(f"Missing required API keys: {', '.join(missing_keys)}") - - - - edge = getstream.Edge() - llm = openai.LLM(model="gpt-4o-mini") - # Create STT with correct sample rate to match our test audio - stt = deepgram.STT(sample_rate=16000) - - # Create agent with STT but explicitly NO TTS - agent = Agent( - edge=edge, - agent_user=User(name="STT Test Agent", id="stt_agent"), - llm=llm, - stt=stt, - tts=None, # ← KEY: No TTS - this is what we're testing - instructions="You are a test agent for STT-only support.", - ) - - # Test 1: Verify agent needs audio input (because STT is present) - assert agent._needs_audio_or_video_input() is True, \ - "Agent with STT should need audio input" - - # Test 2: Verify agent does NOT publish audio (because TTS is None) - assert agent.publish_audio is False, \ - "Agent without TTS should not publish audio" - - # Test 3: Set up event listeners to capture transcript - transcript_events = [] - - @agent.events.subscribe - async def on_transcript(event: STTTranscriptEvent): - transcript_events.append(event) - - # Test 4: Create a test participant (user sending audio) - test_user = User(name="Test User", id="test_user") - test_participant = Participant( - original=test_user, # The original user object - user_id="test_user", # User ID - ) - - # Test 5: Send real audio through the agent's audio processing path - # This simulates what happens when a user speaks in a call - await agent._reply_to_audio(mia_audio_16khz, test_participant) - - # Test 6: Wait for STT to process and emit transcript - # Real STT takes time to process audio and establish connection - await asyncio.sleep(5.0) - - # Test 7: Verify that transcript event was emitted - assert len(transcript_events) > 0, \ - "STT should have emitted at least one transcript event" - - # Test 8: Verify transcript has content - first_transcript = transcript_events[0] - assert first_transcript.text is not None, \ - "Transcript should have text content" - assert 
len(first_transcript.text) > 0, \ - "Transcript text should not be empty" - - # Test 9: Verify user metadata is present - assert first_transcript.user_metadata is not None, \ - "Transcript should have user metadata" - - # Log the transcript for debugging - print(f"✅ STT transcribed: '{first_transcript.text}'") - - # Test 10: Clean up - await stt.close() - await agent.close()
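
Note (not part of the patch): the changes above converge the Moonshine and Wizper plugins on a single subclassing contract for STT providers. A provider implements process_audio(pcm_data, participant), guards on self.closed, wraps provider metadata in a TranscriptResponse, and reports results through _emit_transcript_event() (or _emit_error_event() for recoverable failures), then defers to super().close() on shutdown. The following is a rough sketch of a new provider under that contract, assuming the patterns shown in the Moonshine and Wizper diffs; run_my_asr, the provider name, and the model name are hypothetical placeholders, not code from this repository.

import logging
import time
from typing import TYPE_CHECKING, Optional

from getstream.video.rtc.track_util import PcmData

from vision_agents.core import stt
from vision_agents.core.stt import TranscriptResponse

if TYPE_CHECKING:
    from vision_agents.core.edge.types import Participant

logger = logging.getLogger(__name__)


def run_my_asr(samples, sample_rate: int) -> str:
    # Hypothetical placeholder for the provider's actual transcription call.
    raise NotImplementedError


class STT(stt.STT):
    """Minimal sketch of a provider following the refactored base-class contract."""

    def __init__(self):
        super().__init__(provider_name="myprovider")
        # Sample rate the backend expects; providers now own this themselves.
        self.sample_rate = 16000

    async def process_audio(
        self,
        pcm_data: PcmData,
        participant: Optional["Participant"] = None,
    ):
        if self.closed:
            logger.warning("STT is closed, ignoring audio")
            return
        try:
            start = time.time()
            text = run_my_asr(pcm_data.samples, pcm_data.sample_rate)
            if not text:
                # Nothing recognized; emit no event (mirrors the Wizper guard).
                return
            response_metadata = TranscriptResponse(
                model_name="my-model",
                audio_duration_ms=(len(pcm_data.samples) / pcm_data.sample_rate) * 1000,
                processing_time_ms=(time.time() - start) * 1000,
            )
            self._emit_transcript_event(text, participant, response_metadata)
        except Exception as e:
            # Temporary failures go through the shared error-event helper.
            self._emit_error_event(e, "myprovider audio processing")

    async def close(self):
        if self.closed:
            return
        await super().close()

Consumers receive results as events, e.g. by decorating an async handler with stt.events.subscribe that accepts an STTTranscriptEvent and reads event.text, which is the pattern the test suites in this diff rely on.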