Skip to content

Commit 2c54ccd

Browse files
Adds support for basic turn detection (#2)
* Continued improvements to turn interface
* cleanup
* simplify more
* Add instructions for cursor
* Remove the need for TurnAdapter, pass PCM data directly
* Implement FAL smart-turn integration for turn detection

  - Add FalTurnDetection class implementing the TurnDetection protocol
  - Integrate with FAL AI's smart-turn model for audio-based turn detection
  - Buffer audio data, upload to FAL API, and process predictions
  - Wire into existing Agent class with event-driven architecture
  - Add fal-client dependency and comprehensive example
  - Create example_turn_detection.py demonstrating usage
  - Include documentation and configuration options

  Features:
  - Real-time audio buffering and processing
  - Configurable prediction thresholds and buffer durations
  - Event emission for turn start/end detection
  - Temporary file management and cleanup
  - Integration with Stream's WebRTC audio pipeline

* Simple example using FAL/Smart-Turn

---------

Co-authored-by: Thierry Schellenbach <thierry@getstream.io>
1 parent a554325 commit 2c54ccd

File tree

15 files changed

+1278
-204
lines changed

15 files changed

+1278
-204
lines changed

agents/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,20 @@
55
It uses the stream-py package for STT and TTS services.
66
"""
77

8-
from .agents import Agent, Tool, PreProcessor, LLM, STT, TTS, STS, TurnDetection
8+
from .agents import Agent, Tool, PreProcessor, LLM
9+
10+
# Import STT, TTS, STS from stream-py package (they are imported in agents.py)
11+
try:
12+
from getstream.plugins.common.stt import STT
13+
from getstream.plugins.common.tts import TTS
14+
from getstream.plugins.common.sts import STS
15+
except ImportError:
16+
# Fallback if stream-py is not installed
17+
STT = None
18+
TTS = None
19+
STS = None
20+
21+
# TurnDetectionAdapter removed - use TurnDetection protocol directly
922

1023
__all__ = [
1124
"Agent",
@@ -15,7 +28,8 @@
1528
"STT",
1629
"TTS",
1730
"STS",
18-
"TurnDetection",
1931
]
2032

2133
__version__ = "0.1.0"
34+
35+
# To resume this session: cursor-agent --resume=a524bb8f-60f1-4520-9abf-0ee07d3d1f12

agents/agents.py

Lines changed: 151 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from getstream.plugins.common.tts import TTS
3030
from getstream.plugins.common.vad import VAD
3131
from getstream.plugins.common.sts import STS
32+
from turn_detection.turn_detection import TurnDetection, TurnEvent
3233

3334

3435
class Tool(Protocol):
@@ -58,14 +59,6 @@ async def generate(self, prompt: str, **kwargs) -> str:
5859
# STT and TTS are now imported directly from stream-py package
5960

6061

61-
class TurnDetection(Protocol):
62-
"""Protocol for turn detection services."""
63-
64-
def detect_turn(self, audio_data: bytes) -> bool:
65-
"""Detect if it's the agent's turn to speak."""
66-
...
67-
68-
6962
class ImageProcessor(Protocol):
7063
"""Protocol for image processors."""
7164

@@ -402,16 +395,22 @@ async def join(
402395
self.logger.info("🎥 Video track initialized for transformation publishing")
403396

404397
try:
405-
# Configure subscription based on whether video processing is enabled
398+
# Configure subscription based on what features are enabled
406399
subscription_config = None
400+
track_types = []
401+
402+
# Subscribe to audio if we have STT or turn detection
403+
if self.stt or self.turn_detection:
404+
track_types.append(TrackType.TRACK_TYPE_AUDIO)
405+
406+
# Subscribe to video if we have image processors or video transformer
407407
if self.image_processors or self.video_transformer:
408+
track_types.append(TrackType.TRACK_TYPE_VIDEO)
409+
410+
# Create subscription config if we need any tracks
411+
if track_types:
408412
subscription_config = SubscriptionConfig(
409-
default=TrackSubscriptionConfig(
410-
track_types=[
411-
TrackType.TRACK_TYPE_VIDEO,
412-
TrackType.TRACK_TYPE_AUDIO,
413-
]
414-
)
413+
default=TrackSubscriptionConfig(track_types=track_types)
415414
)
416415

417416
async with await rtc.join(
@@ -435,6 +434,10 @@ async def join(
435434
# Set up event handlers
436435
await self._setup_event_handlers()
437436

437+
# Set up turn detection callbacks (start happens later)
438+
if self.turn_detection:
439+
self._setup_turn_detection_callbacks()
440+
438441
# Execute callback after full initialization to prevent race conditions
439442
if on_connected_callback and not self._callback_executed:
440443
self._callback_executed = True
@@ -455,9 +458,21 @@ async def safe_callback():
455458

456459
asyncio.create_task(safe_callback())
457460

461+
# Start turn detection last, after event handlers and track subscriptions are ready
462+
if self.turn_detection:
463+
try:
464+
self.turn_detection.start()
465+
self.logger.info("Turn detection started")
466+
except Exception as e:
467+
self.logger.error(
468+
f"Failed to start turn detection: {e}", exc_info=True
469+
)
470+
458471
try:
459472
self.logger.info("🎧 Agent is active - press Ctrl+C to stop")
473+
self.logger.debug("Waiting for connection to end...")
460474
await connection.wait()
475+
self.logger.info("Connection ended normally")
461476
except Exception as e:
462477
self.logger.error(f"❌ Error during agent operation: {e}")
463478
self.logger.error(traceback.format_exc())
@@ -478,6 +493,16 @@ async def safe_callback():
478493
self._is_running = False
479494
self._connection = None
480495

496+
# Stop turn detection if available
497+
if self.turn_detection:
498+
try:
499+
self.turn_detection.stop()
500+
self.logger.info("Turn detection stopped")
501+
except Exception as e:
502+
self.logger.warning(
503+
f"Error stopping turn detection: {e}", exc_info=True
504+
)
505+
481506
if self.stt:
482507
try:
483508
await self.stt.close()
@@ -792,35 +817,84 @@ async def on_track_published(event):
792817
if user_id and user_id != self.bot_id:
793818
self.logger.info(f"👋 New participant joined: {user_id}")
794819
await self._handle_new_participant(user_id)
820+
795821
elif user_id == self.bot_id:
796-
self.logger.debug(f'Not subscribing to track: user_id: "{user_id}"')
822+
self.logger.debug(f"Skipping bot's own track: {user_id}")
797823
except Exception as e:
798824
self.logger.error(f"❌ Error handling track published event: {e}")
799825
self.logger.error(traceback.format_exc())
800826

801-
# Handle audio data for STT using Stream SDK pattern
827+
# Handle audio data for STT and turn detection using Stream SDK pattern
802828
@self._connection.on("audio")
803829
async def on_audio_received(pcm, user):
804830
"""Handle incoming audio data from participants."""
805831
try:
806-
if self.stt and user and user != self.bot_id:
832+
# Skip if it's the bot's own audio
833+
if user is not None and (
834+
(hasattr(user, "user_id") and user.user_id == self.bot_id)
835+
or (isinstance(user, str) and user == self.bot_id)
836+
):
837+
return
838+
839+
# Log audio arrival for diagnostics
840+
try:
841+
length = len(pcm.data) if hasattr(pcm, "data") else len(pcm)
842+
except Exception:
843+
length = -1
844+
uid = user.user_id if hasattr(user, "user_id") else str(user)
845+
self.logger.debug(f"🔈 Audio event from {uid}: {length} bytes")
846+
847+
# Process for turn detection (independent of STT)
848+
if self.turn_detection and user:
849+
user_id = user.user_id if hasattr(user, "user_id") else str(user)
850+
851+
# Process audio for turn detection
852+
try:
853+
await self.turn_detection.process_audio(
854+
pcm,
855+
user_id,
856+
metadata={"timestamp": asyncio.get_event_loop().time()},
857+
)
858+
except Exception as e:
859+
self.logger.error(
860+
f"Turn detection process_audio error: {e}", exc_info=True
861+
)
862+
863+
# Also process for STT if available
864+
if self.stt and user:
807865
await self._handle_audio_input(pcm, user)
866+
808867
except Exception as e:
809868
self.logger.error(f"Error handling audio received event: {e}")
810869
self.logger.error(traceback.format_exc())
811870

812871
# Set up video track handler if image processors or video transformer are configured
813-
if (self.image_processors or self.video_transformer) and self._connection:
872+
if self._connection:
814873

815874
def on_track_added(track_id, track_type, user):
816875
user_id = user.user_id if user else "unknown"
817876
self.logger.info(
818877
f"🎬 New track detected: {track_id} ({track_type}) from {user_id}"
819878
)
820-
if track_type == "video":
879+
# Handle video tracks
880+
if (
881+
track_type == "video"
882+
or getattr(track_type, "value", track_type)
883+
== TrackType.TRACK_TYPE_VIDEO
884+
or track_type == 2
885+
) and (self.image_processors or self.video_transformer):
821886
asyncio.create_task(
822887
self._process_video_track(track_id, track_type, user)
823888
)
889+
# Handle audio tracks for turn detection via unified interface if available
890+
elif self.turn_detection and (
891+
track_type == "audio"
892+
or getattr(track_type, "value", track_type)
893+
== TrackType.TRACK_TYPE_AUDIO
894+
or track_type == 1
895+
):
896+
# Turn detection will automatically track participants via process_audio
897+
pass
824898

825899
self._connection.on("track_added", on_track_added)
826900

@@ -857,15 +931,12 @@ def on_tts_error(error, user=None, metadata=None):
857931
self.logger.error(traceback.format_exc())
858932

859933
async def _handle_audio_input(self, pcm_data, user) -> None:
860-
"""Handle incoming audio data from Stream WebRTC connection."""
934+
"""Handle incoming audio data from Stream WebRTC connection for STT."""
861935
if not self.stt:
862936
return
863937

864938
try:
865-
# Check if it's our turn to respond (if turn detection is configured)
866-
if self.turn_detection and hasattr(pcm_data, "data"):
867-
if not self.turn_detection.detect_turn(pcm_data.data):
868-
return
939+
# STT processes all audio continuously for best quality
869940

870941
# Set up event listeners for transcription results (one-time setup)
871942
if not hasattr(self, "_stt_setup"):
@@ -880,6 +951,7 @@ async def _handle_audio_input(self, pcm_data, user) -> None:
880951
await self._process_audio_with_vad(pcm_data, user)
881952
else:
882953
# Without VAD: Process all audio directly through STT
954+
# STT needs continuous audio stream for services like Deepgram
883955
await self.stt.process_audio(pcm_data, user)
884956

885957
except Exception as e:
@@ -984,6 +1056,16 @@ async def _on_tts_error(self, error, user=None, metadata=None):
9841056
async def _process_transcription(self, text: str, user=None) -> None:
9851057
"""Process a complete transcription and generate response."""
9861058
try:
1059+
# Check if it's the agent's turn to respond
1060+
if self.turn_detection:
1061+
# Use turn detection callbacks/events rather than polling detect_turn() with transcripts
1062+
# The turn detection system will signal via callbacks when it's time to respond
1063+
if not self._should_respond_based_on_turn_detection():
1064+
self.logger.debug(
1065+
f"Turn detection: Not agent's turn to respond to: {text[:50]}..."
1066+
)
1067+
return
1068+
9871069
# Process with pre-processors
9881070
processed_data = text
9891071
for processor in self.pre_processors:
@@ -1113,3 +1195,47 @@ async def stop(self) -> None:
11131195
except Exception as e:
11141196
self.logger.error(f"Error during agent cleanup: {e}")
11151197
self.logger.error(traceback.format_exc())
1198+
1199+
def _setup_turn_detection_callbacks(self) -> None:
1200+
"""Set up turn detection callbacks using the event system."""
1201+
try:
1202+
# Set up standard turn detection events
1203+
self.turn_detection.on(TurnEvent.TURN_STARTED.value, self._on_turn_started)
1204+
self.turn_detection.on(TurnEvent.TURN_ENDED.value, self._on_turn_ended)
1205+
self.turn_detection.on(
1206+
TurnEvent.SPEECH_STARTED.value, self._on_speech_started_td
1207+
)
1208+
self.turn_detection.on(
1209+
TurnEvent.SPEECH_ENDED.value, self._on_speech_ended_td
1210+
)
1211+
1212+
except Exception as e:
1213+
self.logger.warning(f"Error setting up turn detection callbacks: {e}")
1214+
1215+
def _should_respond_based_on_turn_detection(self) -> bool:
1216+
"""Check if the agent should respond based on current turn detection state."""
1217+
try:
1218+
# Check if turn detection is currently active
1219+
return self.turn_detection.is_detecting()
1220+
1221+
except Exception as e:
1222+
self.logger.debug(
1223+
f"Turn detection check error (defaulting to allow response): {e}"
1224+
)
1225+
return True # Default to allowing response on error
1226+
1227+
def _on_turn_started(self, event_data) -> None:
1228+
"""Handle when a participant starts their turn."""
1229+
self.logger.info("Turn started - participant speaking")
1230+
1231+
def _on_turn_ended(self, event_data) -> None:
1232+
"""Handle when a participant ends their turn."""
1233+
self.logger.info("Turn ended - agent may respond")
1234+
1235+
def _on_speech_started_td(self, event_data) -> None:
1236+
"""Handle speech start from turn detection."""
1237+
self.logger.debug("Turn detection: Speech started")
1238+
1239+
def _on_speech_ended_td(self, event_data) -> None:
1240+
"""Handle speech end from turn detection."""
1241+
self.logger.debug("Turn detection: Speech ended")

agents/agents2.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ def get_subscription_config(self):
134134
]
135135
)
136136

137+
def get_subscription_config(self):
138+
return TrackSubscriptionConfig(
139+
track_types=[
140+
TrackType.TRACK_TYPE_VIDEO,
141+
TrackType.TRACK_TYPE_AUDIO,
142+
]
143+
)
144+
137145
async def join(self, call) -> None:
138146
"""Join a Stream video call."""
139147
if self._is_running:
@@ -177,8 +185,10 @@ async def join(self, call) -> None:
177185
# Set up video track if available
178186
if self.publish_video:
179187
await connection.add_tracks(video=self._video_track)
188+
180189
self.logger.debug("🎥 Agent ready to publish video")
181190

191+
182192
# Set up STS audio forwarding if in STS mode
183193
if self.sts_mode and self._sts_connection:
184194
self.logger.info("🎥 STS audio. Forward from openAI to Stream")
@@ -604,6 +614,36 @@ async def close(self):
604614
finally:
605615
self._interval_task = None
606616

617+
async def close(self):
618+
"""Clean up all connections and resources."""
619+
self._is_running = False
620+
621+
if self._sts_connection:
622+
await self._sts_connection.__aexit__(None, None, None)
623+
self._sts_connection = None
624+
625+
if self._connection:
626+
await self._connection.__aexit__(None, None, None)
627+
self._connection = None
628+
629+
if self.stt:
630+
await self.stt.close()
631+
632+
if self.tts:
633+
await self.tts.close()
634+
635+
if self._audio_track:
636+
self._audio_track.stop()
637+
self._audio_track = None
638+
639+
if self._video_track:
640+
self._video_track.stop()
641+
self._video_track = None
642+
643+
if self._interval_task:
644+
self._interval_task.cancel()
645+
self._interval_task = None
646+
607647
def create_user(self):
608648
"""Create user - placeholder for any user setup logic."""
609649
pass

0 commit comments

Comments
 (0)