
Commit c09e360

ensure agent user is initialized before joining the call

1 parent 1025a42 commit c09e360

6 files changed: +244 -185 lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 88 additions & 64 deletions
@@ -50,6 +50,7 @@ def _log_task_exception(task: asyncio.Task):
     except Exception:
         logger.exception("Error in background task")
 
+
 class Agent:
     """
     Agent class makes it easy to build your own video AI.
@@ -102,6 +103,7 @@ def __init__(
         self.instructions = instructions
         self.edge = edge
         self.agent_user = agent_user
+        self._agent_user_initialized = False
 
         # only needed in case we spin threads
         self._root_span = trace.get_current_span()
@@ -124,15 +126,19 @@ def __init__(
         self._call_context_token: CallContextToken | None = None
 
         # Initialize MCP manager if servers are provided
-        self.mcp_manager = MCPManager(self.mcp_servers, self.llm, self.logger) if self.mcp_servers else None
+        self.mcp_manager = (
+            MCPManager(self.mcp_servers, self.llm, self.logger)
+            if self.mcp_servers
+            else None
+        )
 
         # we sync the user talking and the agent responses to the conversation
         # because we want to support streaming responses and can have delta updates for both
         # user and agent we keep an handle for both
         self.conversation: Optional[Conversation] = None
         self._user_conversation_handle: Optional[StreamHandle] = None
         self._agent_conversation_handle: Optional[StreamHandle] = None
-
+
         # Track pending transcripts for turn-based response triggering
         self._pending_user_transcripts: Dict[str, str] = {}
 
@@ -153,7 +159,7 @@ def __init__(
         self._current_frame = None
         self._interval_task = None
         self._callback_executed = False
-        self._track_tasks : Dict[str, asyncio.Task] = {}
+        self._track_tasks: Dict[str, asyncio.Task] = {}
         self._connection: Optional[Connection] = None
         self._audio_track: Optional[aiortc.AudioStreamTrack] = None
         self._video_track: Optional[VideoStreamTrack] = None
@@ -194,8 +200,9 @@ def subscribe(self, function):
         """
         return self.events.subscribe(function)
 
-
     async def join(self, call: Call) -> "AgentSessionContextManager":
+        await self.create_user()
+
         # TODO: validation. join can only be called once
         with self.tracer.start_as_current_span("join"):
             if self._is_running:
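The create_user() call at the top of join() is the behavioral change in this commit: joining now guarantees the agent user exists before the call session starts. A minimal caller sketch, assuming an Agent wired up as in this repo's examples (the constructor arguments and the User type here are illustrative, not the exact API):

    # Hypothetical usage: no separate create_user() call is needed anymore.
    agent = Agent(edge=edge, agent_user=User(id=""), llm=llm, instructions="...")
    session = await agent.join(call)  # create_user() runs first, then the join span
    # agent.agent_user.id is now populated, e.g. "agent-<uuid4>"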
@@ -311,9 +318,9 @@ async def close(self):
 
         for processor in self.processors:
             processor.close()
-
+
         # Stop all video forwarders
-        if hasattr(self, '_video_forwarders'):
+        if hasattr(self, "_video_forwarders"):
             for forwarder in self._video_forwarders:
                 try:
                     await forwarder.stop()
@@ -382,16 +389,18 @@ def clear_call_logging_context(self) -> None:
             clear_call_context(self._call_context_token)
             self._call_context_token = None
 
-    async def create_user(self):
-        """Create the agent user in the edge provider, if required.
+    async def create_user(self) -> None:
+        """Create the agent user in the edge provider, if required."""
+
+        if self._agent_user_initialized:
+            return None
 
-        Returns:
-            Provider-specific user creation response.
-        """
         with self.tracer.start_as_current_span("edge.create_user"):
-            if self.agent_user.id == "":
-                self.agent_user.id = str(uuid4())
-            return await self.edge.create_user(self.agent_user)
+            if not self.agent_user.id:
+                self.agent_user.id = f"agent-{uuid4()}"
+            await self.edge.create_user(self.agent_user)
+
+        return None
 
     async def _handle_output_text_delta(self, event: LLMResponseChunkEvent):
         """Handle partial LLM response text deltas."""
@@ -499,23 +508,30 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
             )
             self.logger.error(f"Error in agent say: {e}")
 
-    async def say(self, text: str, user_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None):
+    async def say(
+        self,
+        text: str,
+        user_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ):
         """
         Make the agent say something using TTS.
-
+
         This is a convenience method that sends an AgentSayEvent to trigger TTS synthesis.
-
+
         Args:
             text: The text for the agent to say
             user_id: Optional user ID for the speech
             metadata: Optional metadata to include with the speech
         """
-        self.events.send(events.AgentSayEvent(
-            plugin_name="agent",
-            text=text,
-            user_id=user_id or self.agent_user.id,
-            metadata=metadata
-        ))
+        self.events.send(
+            events.AgentSayEvent(
+                plugin_name="agent",
+                text=text,
+                user_id=user_id or self.agent_user.id,
+                metadata=metadata,
+            )
+        )
 
     def _setup_turn_detection(self):
         if self.turn_detection:
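Because say() is only a thin wrapper that emits AgentSayEvent, speech can also be observed through the event bus. A short usage sketch; the decorator form of subscribe() and the handler signature are assumptions extrapolated from the subscribe() method shown earlier:

    await agent.say("Thanks for joining!", metadata={"source": "greeting"})

    @agent.subscribe
    async def on_agent_say(event: events.AgentSayEvent):
        # user_id falls back to agent.agent_user.id when the caller omits it
        print(f"{event.user_id} says: {event.text}")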
@@ -571,12 +587,11 @@ async def _reply_to_audio(
                     continue
                 await processor.process_audio(audio_bytes, participant.user_id)
 
-
         # when in Realtime mode call the Realtime directly (non-blocking)
         if self.realtime_mode and isinstance(self.llm, Realtime):
             # TODO: this behaviour should be easy to change in the agent class
             asyncio.create_task(self.llm.simple_audio_response(pcm_data))
-            #task.add_done_callback(lambda t: print(f"Task (send_audio_pcm) error: {t.exception()}"))
+            # task.add_done_callback(lambda t: print(f"Task (send_audio_pcm) error: {t.exception()}"))
         # Process audio through STT
         elif self.stt:
             self.logger.debug(f"🎵 Processing audio from {participant}")
@@ -591,14 +606,12 @@ async def _process_track(self, track_id: str, track_type: int, participant):
        # subscribe to the video track
        track = self.edge.add_track_subscriber(track_id)
        if not track:
-            self.logger.error(
-                f"Failed to subscribe to {track_id}"
-            )
+            self.logger.error(f"Failed to subscribe to {track_id}")
            return
 
        # Import VideoForwarder
        from ..utils.video_forwarder import VideoForwarder
-
+
        # Create a SHARED VideoForwarder for the RAW incoming track
        # This prevents multiple recv() calls competing on the same track
        raw_forwarder = VideoForwarder(
@@ -609,9 +622,9 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         )
         await raw_forwarder.start()
         self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id)
-
+
         # Track forwarders for cleanup
-        if not hasattr(self, '_video_forwarders'):
+        if not hasattr(self, "_video_forwarders"):
             self._video_forwarders = []
         self._video_forwarders.append(raw_forwarder)
 
@@ -620,7 +633,9 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         if self._video_track:
             # We have a video publisher (e.g., YOLO processor)
             # Create a separate forwarder for the PROCESSED video track
-            self.logger.info("🎥 Forwarding PROCESSED video frames to Realtime provider")
+            self.logger.info(
+                "🎥 Forwarding PROCESSED video frames to Realtime provider"
+            )
             processed_forwarder = VideoForwarder(
                 self._video_track,  # type: ignore[arg-type]
                 max_buffer=30,
629644
)
630645
await processed_forwarder.start()
631646
self._video_forwarders.append(processed_forwarder)
632-
647+
633648
if isinstance(self.llm, Realtime):
634649
# Send PROCESSED frames with the processed forwarder
635-
await self.llm._watch_video_track(self._video_track, shared_forwarder=processed_forwarder)
650+
await self.llm._watch_video_track(
651+
self._video_track, shared_forwarder=processed_forwarder
652+
)
636653
else:
637654
# No video publisher, send raw frames
638655
self.logger.info("🎥 Forwarding RAW video frames to Realtime provider")
639656
if isinstance(self.llm, Realtime):
640-
await self.llm._watch_video_track(track, shared_forwarder=raw_forwarder)
641-
657+
await self.llm._watch_video_track(
658+
track, shared_forwarder=raw_forwarder
659+
)
642660

643661
hasImageProcessers = len(self.image_processors) > 0
644662

645663
# video processors - pass the raw forwarder (they process incoming frames)
646664
for processor in self.video_processors:
647665
try:
648-
await processor.process_video(track, participant.user_id, shared_forwarder=raw_forwarder)
666+
await processor.process_video(
667+
track, participant.user_id, shared_forwarder=raw_forwarder
668+
)
649669
except Exception as e:
650670
self.logger.error(
651671
f"Error in video processor {type(processor).__name__}: {e}"
@@ -654,13 +674,15 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         # Use raw forwarder for image processors - only if there are image processors
         if not hasImageProcessers:
             # No image processors, just keep the connection alive
-            self.logger.info("No image processors, video processing handled by video processors only")
+            self.logger.info(
+                "No image processors, video processing handled by video processors only"
+            )
             return
-
+
         # Initialize error tracking counters
         timeout_errors = 0
         consecutive_errors = 0
-
+
         while True:
             try:
                 # Use the raw forwarder instead of competing for track.recv()
672694
consecutive_errors = 0
673695

674696
if hasImageProcessers:
675-
676697
img = video_frame.to_image()
677698

678699
for processor in self.image_processors:
@@ -683,7 +704,6 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                                     f"Error in image processor {type(processor).__name__}: {e}"
                                 )
 
-
                 else:
                     self.logger.warning("🎥VDP: Received empty frame")
                     consecutive_errors += 1
@@ -698,14 +718,16 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                 await asyncio.sleep(backoff_delay)
 
         # Cleanup and logging
-        self.logger.info(f"🎥VDP: Video processing loop ended for track {track_id} - timeouts: {timeout_errors}, consecutive_errors: {consecutive_errors}")
+        self.logger.info(
+            f"🎥VDP: Video processing loop ended for track {track_id} - timeouts: {timeout_errors}, consecutive_errors: {consecutive_errors}"
+        )
 
     async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
         """Handle turn detection events."""
         # In realtime mode, the LLM handles turn detection, interruption, and responses itself
         if self.realtime_mode:
             return
-
+
         if isinstance(event, TurnStartedEvent):
             # Interrupt TTS when user starts speaking (barge-in)
             if event.speaker_id and event.speaker_id != self.agent_user.id:
@@ -730,26 +752,28 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
             self.logger.info(
                 f"👉 Turn ended - participant {event.speaker_id} finished (duration: {event.duration}, confidence: {event.confidence})"
             )
-
+
             # When turn detection is enabled, trigger LLM response when user's turn ends
             # This is the signal that the user has finished speaking and expects a response
             if event.speaker_id and event.speaker_id != self.agent_user.id:
                 # Get the accumulated transcript for this speaker
                 transcript = self._pending_user_transcripts.get(event.speaker_id, "")
-
+
                 if transcript and transcript.strip():
-                    self.logger.info(f"🤖 Triggering LLM response after turn ended for {event.speaker_id}")
-
+                    self.logger.info(
+                        f"🤖 Triggering LLM response after turn ended for {event.speaker_id}"
+                    )
+
                     # Create participant object if we have metadata
                     participant = None
-                    if hasattr(event, 'custom') and event.custom:
+                    if hasattr(event, "custom") and event.custom:
                         # Try to extract participant info from custom metadata
-                        participant = event.custom.get('participant')
-
+                        participant = event.custom.get("participant")
+
                     # Trigger LLM response with the complete transcript
                     if self.llm:
                         await self.simple_response(transcript, participant)
-
+
                     # Clear the pending transcript for this speaker
                     self._pending_user_transcripts[event.speaker_id] = ""
 
@@ -806,12 +830,12 @@ async def _on_transcript(self, event: STTTranscriptEvent | RealtimeTranscriptEve
             )
             self.conversation.complete_message(self._user_conversation_handle)
             self._user_conversation_handle = None
-
+
         # In realtime mode, the LLM handles everything itself (STT, turn detection, responses)
         # Skip our manual LLM triggering logic
         if self.realtime_mode:
             return
-
+
         # Determine how to handle LLM triggering based on turn detection
         if self.turn_detection is not None:
             # With turn detection: accumulate transcripts and wait for TurnEndedEvent
@@ -821,7 +845,7 @@ async def _on_transcript(self, event: STTTranscriptEvent | RealtimeTranscriptEve
             else:
                 # Append to existing transcript (user might be speaking in chunks)
                 self._pending_user_transcripts[user_id] += " " + event.text
-
+
             self.logger.debug(
                 f"📝 Accumulated transcript for {user_id} (waiting for turn end): "
                 f"{self._pending_user_transcripts[user_id][:100]}..."
830854
# Without turn detection: trigger LLM immediately on transcript completion
831855
# This is the traditional STT -> LLM flow
832856
if self.llm:
833-
self.logger.info("🤖 Triggering LLM response immediately (no turn detection)")
834-
857+
self.logger.info(
858+
"🤖 Triggering LLM response immediately (no turn detection)"
859+
)
860+
835861
# Get participant from event metadata
836862
participant = None
837863
if hasattr(event, "user_metadata"):
838864
participant = event.user_metadata
839-
865+
840866
await self.simple_response(event.text, participant)
841867

842868
async def _on_stt_error(self, error):
843869
"""Handle STT service errors."""
844870
self.logger.error(f"❌ STT Error: {error}")
845871

846-
847-
848872
@property
849873
def realtime_mode(self) -> bool:
850874
"""Check if the agent is in Realtime mode.
@@ -869,8 +893,7 @@ def publish_audio(self) -> bool:
 
     @property
     def publish_video(self) -> bool:
-        """Whether the agent should publish an outbound video track.
-        """
+        """Whether the agent should publish an outbound video track."""
         return len(self.video_publishers) > 0
 
     def _needs_audio_or_video_input(self) -> bool:
@@ -1000,7 +1023,9 @@ def _prepare_rtc(self):
         else:
             framerate = 48000
             stereo = True  # Default to stereo for WebRTC
-        self._audio_track = self.edge.create_audio_track(framerate=framerate, stereo=stereo)
+        self._audio_track = self.edge.create_audio_track(
+            framerate=framerate, stereo=stereo
+        )
         if self.tts:
             self.tts.set_output_track(self._audio_track)
 
@@ -1012,7 +1037,6 @@ def _prepare_rtc(self):
             self._video_track = video_publisher.publish_video_track()
             self.logger.info("🎥 Video track initialized from video publisher")
 
-
     def _truncate_for_logging(self, obj, max_length=200):
         """Truncate object string representation for logging to prevent spam."""
         obj_str = str(obj)
