2929from getstream .plugins .common .tts import TTS
3030from getstream .plugins .common .vad import VAD
3131from getstream .plugins .common .sts import STS
32+ from turn_detection .turn_detection import TurnDetection , TurnEvent
3233
3334
3435class Tool (Protocol ):
@@ -58,14 +59,6 @@ async def generate(self, prompt: str, **kwargs) -> str:
5859# STT and TTS are now imported directly from stream-py package
5960
6061
61- class TurnDetection (Protocol ):
62- """Protocol for turn detection services."""
63-
64- def detect_turn (self , audio_data : bytes ) -> bool :
65- """Detect if it's the agent's turn to speak."""
66- ...
67-
68-
6962class ImageProcessor (Protocol ):
7063 """Protocol for image processors."""
7164
@@ -402,16 +395,22 @@ async def join(
402395 self .logger .info ("🎥 Video track initialized for transformation publishing" )
403396
404397 try :
405- # Configure subscription based on whether video processing is enabled
398+ # Configure subscription based on what features are enabled
406399 subscription_config = None
400+ track_types = []
401+
402+ # Subscribe to audio if we have STT or turn detection
403+ if self .stt or self .turn_detection :
404+ track_types .append (TrackType .TRACK_TYPE_AUDIO )
405+
406+ # Subscribe to video if we have image processors or video transformer
407407 if self .image_processors or self .video_transformer :
408+ track_types .append (TrackType .TRACK_TYPE_VIDEO )
409+
410+ # Create subscription config if we need any tracks
411+ if track_types :
408412 subscription_config = SubscriptionConfig (
409- default = TrackSubscriptionConfig (
410- track_types = [
411- TrackType .TRACK_TYPE_VIDEO ,
412- TrackType .TRACK_TYPE_AUDIO ,
413- ]
414- )
413+ default = TrackSubscriptionConfig (track_types = track_types )
415414 )
416415
417416 async with await rtc .join (
@@ -435,6 +434,10 @@ async def join(
435434 # Set up event handlers
436435 await self ._setup_event_handlers ()
437436
437+ # Set up turn detection callbacks (start happens later)
438+ if self .turn_detection :
439+ self ._setup_turn_detection_callbacks ()
440+
438441 # Execute callback after full initialization to prevent race conditions
439442 if on_connected_callback and not self ._callback_executed :
440443 self ._callback_executed = True
@@ -455,9 +458,21 @@ async def safe_callback():
455458
456459 asyncio .create_task (safe_callback ())
457460
461+ # Start turn detection last, after event handlers and track subscriptions are ready
462+ if self .turn_detection :
463+ try :
464+ self .turn_detection .start ()
465+ self .logger .info ("Turn detection started" )
466+ except Exception as e :
467+ self .logger .error (
468+ f"Failed to start turn detection: { e } " , exc_info = True
469+ )
470+
458471 try :
459472 self .logger .info ("🎧 Agent is active - press Ctrl+C to stop" )
473+ self .logger .debug ("Waiting for connection to end..." )
460474 await connection .wait ()
475+ self .logger .info ("Connection ended normally" )
461476 except Exception as e :
462477 self .logger .error (f"❌ Error during agent operation: { e } " )
463478 self .logger .error (traceback .format_exc ())
@@ -478,6 +493,16 @@ async def safe_callback():
478493 self ._is_running = False
479494 self ._connection = None
480495
496+ # Stop turn detection if available
497+ if self .turn_detection :
498+ try :
499+ self .turn_detection .stop ()
500+ self .logger .info ("Turn detection stopped" )
501+ except Exception as e :
502+ self .logger .warning (
503+ f"Error stopping turn detection: { e } " , exc_info = True
504+ )
505+
481506 if self .stt :
482507 try :
483508 await self .stt .close ()
@@ -792,35 +817,84 @@ async def on_track_published(event):
792817 if user_id and user_id != self .bot_id :
793818 self .logger .info (f"👋 New participant joined: { user_id } " )
794819 await self ._handle_new_participant (user_id )
820+
795821 elif user_id == self .bot_id :
796- self .logger .debug (f'Not subscribing to track: user_id: " { user_id } "' )
822+ self .logger .debug (f"Skipping bot's own track: { user_id } " )
797823 except Exception as e :
798824 self .logger .error (f"❌ Error handling track published event: { e } " )
799825 self .logger .error (traceback .format_exc ())
800826
801- # Handle audio data for STT using Stream SDK pattern
827+ # Handle audio data for STT and turn detection using Stream SDK pattern
802828 @self ._connection .on ("audio" )
803829 async def on_audio_received (pcm , user ):
804830 """Handle incoming audio data from participants."""
805831 try :
806- if self .stt and user and user != self .bot_id :
832+ # Skip if it's the bot's own audio
833+ if user is not None and (
834+ (hasattr (user , "user_id" ) and user .user_id == self .bot_id )
835+ or (isinstance (user , str ) and user == self .bot_id )
836+ ):
837+ return
838+
839+ # Log audio arrival for diagnostics
840+ try :
841+ length = len (pcm .data ) if hasattr (pcm , "data" ) else len (pcm )
842+ except Exception :
843+ length = - 1
844+ uid = user .user_id if hasattr (user , "user_id" ) else str (user )
845+ self .logger .debug (f"🔈 Audio event from { uid } : { length } bytes" )
846+
847+ # Process for turn detection (independent of STT)
848+ if self .turn_detection and user :
849+ user_id = user .user_id if hasattr (user , "user_id" ) else str (user )
850+
851+ # Process audio for turn detection
852+ try :
853+ await self .turn_detection .process_audio (
854+ pcm ,
855+ user_id ,
856+ metadata = {"timestamp" : asyncio .get_event_loop ().time ()},
857+ )
858+ except Exception as e :
859+ self .logger .error (
860+ f"Turn detection process_audio error: { e } " , exc_info = True
861+ )
862+
863+ # Also process for STT if available
864+ if self .stt and user :
807865 await self ._handle_audio_input (pcm , user )
866+
808867 except Exception as e :
809868 self .logger .error (f"Error handling audio received event: { e } " )
810869 self .logger .error (traceback .format_exc ())
811870
812871 # Set up video track handler if image processors or video transformer are configured
813- if ( self . image_processors or self . video_transformer ) and self ._connection :
872+ if self ._connection :
814873
815874 def on_track_added (track_id , track_type , user ):
816875 user_id = user .user_id if user else "unknown"
817876 self .logger .info (
818877 f"🎬 New track detected: { track_id } ({ track_type } ) from { user_id } "
819878 )
820- if track_type == "video" :
879+ # Handle video tracks
880+ if (
881+ track_type == "video"
882+ or getattr (track_type , "value" , track_type )
883+ == TrackType .TRACK_TYPE_VIDEO
884+ or track_type == 2
885+ ) and (self .image_processors or self .video_transformer ):
821886 asyncio .create_task (
822887 self ._process_video_track (track_id , track_type , user )
823888 )
889+ # Handle audio tracks for turn detection via unified interface if available
890+ elif self .turn_detection and (
891+ track_type == "audio"
892+ or getattr (track_type , "value" , track_type )
893+ == TrackType .TRACK_TYPE_AUDIO
894+ or track_type == 1
895+ ):
896+ # Turn detection will automatically track participants via process_audio
897+ pass
824898
825899 self ._connection .on ("track_added" , on_track_added )
826900
@@ -857,15 +931,12 @@ def on_tts_error(error, user=None, metadata=None):
857931 self .logger .error (traceback .format_exc ())
858932
859933 async def _handle_audio_input (self , pcm_data , user ) -> None :
860- """Handle incoming audio data from Stream WebRTC connection."""
934+ """Handle incoming audio data from Stream WebRTC connection for STT ."""
861935 if not self .stt :
862936 return
863937
864938 try :
865- # Check if it's our turn to respond (if turn detection is configured)
866- if self .turn_detection and hasattr (pcm_data , "data" ):
867- if not self .turn_detection .detect_turn (pcm_data .data ):
868- return
939+ # STT processes all audio continuously for best quality
869940
870941 # Set up event listeners for transcription results (one-time setup)
871942 if not hasattr (self , "_stt_setup" ):
@@ -880,6 +951,7 @@ async def _handle_audio_input(self, pcm_data, user) -> None:
880951 await self ._process_audio_with_vad (pcm_data , user )
881952 else :
882953 # Without VAD: Process all audio directly through STT
954+ # STT needs continuous audio stream for services like Deepgram
883955 await self .stt .process_audio (pcm_data , user )
884956
885957 except Exception as e :
@@ -984,6 +1056,16 @@ async def _on_tts_error(self, error, user=None, metadata=None):
9841056 async def _process_transcription (self , text : str , user = None ) -> None :
9851057 """Process a complete transcription and generate response."""
9861058 try :
1059+ # Check if it's the agent's turn to respond
1060+ if self .turn_detection :
1061+ # Use turn detection callbacks/events rather than polling detect_turn() with transcripts
1062+ # The turn detection system will signal via callbacks when it's time to respond
1063+ if not self ._should_respond_based_on_turn_detection ():
1064+ self .logger .debug (
1065+ f"Turn detection: Not agent's turn to respond to: { text [:50 ]} ..."
1066+ )
1067+ return
1068+
9871069 # Process with pre-processors
9881070 processed_data = text
9891071 for processor in self .pre_processors :
@@ -1113,3 +1195,47 @@ async def stop(self) -> None:
11131195 except Exception as e :
11141196 self .logger .error (f"Error during agent cleanup: { e } " )
11151197 self .logger .error (traceback .format_exc ())
1198+
1199+ def _setup_turn_detection_callbacks (self ) -> None :
1200+ """Set up turn detection callbacks using the event system."""
1201+ try :
1202+ # Set up standard turn detection events
1203+ self .turn_detection .on (TurnEvent .TURN_STARTED .value , self ._on_turn_started )
1204+ self .turn_detection .on (TurnEvent .TURN_ENDED .value , self ._on_turn_ended )
1205+ self .turn_detection .on (
1206+ TurnEvent .SPEECH_STARTED .value , self ._on_speech_started_td
1207+ )
1208+ self .turn_detection .on (
1209+ TurnEvent .SPEECH_ENDED .value , self ._on_speech_ended_td
1210+ )
1211+
1212+ except Exception as e :
1213+ self .logger .warning (f"Error setting up turn detection callbacks: { e } " )
1214+
1215+ def _should_respond_based_on_turn_detection (self ) -> bool :
1216+ """Check if the agent should respond based on current turn detection state."""
1217+ try :
1218+ # Check if turn detection is currently active
1219+ return self .turn_detection .is_detecting ()
1220+
1221+ except Exception as e :
1222+ self .logger .debug (
1223+ f"Turn detection check error (defaulting to allow response): { e } "
1224+ )
1225+ return True # Default to allowing response on error
1226+
1227+ def _on_turn_started (self , event_data ) -> None :
1228+ """Handle when a participant starts their turn."""
1229+ self .logger .info ("Turn started - participant speaking" )
1230+
1231+ def _on_turn_ended (self , event_data ) -> None :
1232+ """Handle when a participant ends their turn."""
1233+ self .logger .info ("Turn ended - agent may respond" )
1234+
1235+ def _on_speech_started_td (self , event_data ) -> None :
1236+ """Handle speech start from turn detection."""
1237+ self .logger .debug ("Turn detection: Speech started" )
1238+
1239+ def _on_speech_ended_td (self , event_data ) -> None :
1240+ """Handle speech end from turn detection."""
1241+ self .logger .debug ("Turn detection: Speech ended" )
0 commit comments