Skip to content

Commit 6a05c30

Browse files
authored
Merge pull request #81 from GetStream/hotfix/agent-finish-events-merge
hotfix: events, agent finish
2 parents a9397b3 + dfdcca6 commit 6a05c30

File tree

4 files changed

+54
-22
lines changed

4 files changed

+54
-22
lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from opentelemetry.trace import Tracer
1313

1414
from ..edge import sfu_events
15-
from ..edge.events import AudioReceivedEvent, TrackAddedEvent
15+
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
1616
from ..edge.types import Connection, Participant, PcmData, TrackType, User
1717
from ..events.manager import EventManager
1818
from ..llm.events import (
@@ -267,6 +267,28 @@ async def join(self, call: Call) -> "AgentSessionContextManager":
267267

268268
return AgentSessionContextManager(self, self._connection)
269269

270+
async def finish(self):
271+
"""Wait for the call to end gracefully.
272+
Subscribes to the edge transport's `call_ended` event and awaits it. If
273+
no connection is active, returns immediately.
274+
"""
275+
# If connection is None or already closed, return immediately
276+
if not self._connection:
277+
logging.info("🔚 Agent connection already closed, finishing immediately")
278+
return
279+
280+
@self.edge.events.subscribe
281+
async def on_ended(event: CallEndedEvent):
282+
self._is_running = False
283+
284+
while self._is_running:
285+
try:
286+
await asyncio.sleep(0.0001)
287+
except asyncio.CancelledError:
288+
self._is_running = False
289+
290+
await asyncio.shield(self.close())
291+
270292
async def close(self):
271293
"""Clean up all connections and resources.
272294

agents-core/vision_agents/core/events/manager.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import uuid
23
import collections
34
import logging
45
import types
@@ -141,6 +142,7 @@ def __init__(self, ignore_unknown_events: bool = True):
141142
self._processing_task: Optional[asyncio.Task[Any]] = None
142143
self._shutdown = False
143144
self._silent_events: set[type] = set()
145+
self._handler_tasks: Dict[uuid.UUID, asyncio.Task[Any]] = {}
144146

145147
self.register(ExceptionEvent)
146148
self.register(HealthCheckEvent)
@@ -182,6 +184,8 @@ def register(self, event_class, ignore_not_compatible=False):
182184
# raise KeyError(f"{event_class.type} is already registered.")
183185
self._events[event_class.type] = event_class
184186
logger.info(f"Registered new event {event_class} - {event_class.type}")
187+
elif event_class.__name__.endswith('BaseEvent'):
188+
return
185189
elif not ignore_not_compatible:
186190
raise ValueError(f"Provide valid class that ends on '*Event' and 'type' attribute: {event_class}")
187191
else:
@@ -458,6 +462,9 @@ async def wait(self, timeout: float = 10.0):
458462
start_time = asyncio.get_event_loop().time()
459463
while self._queue and (asyncio.get_event_loop().time() - start_time) < timeout:
460464
await asyncio.sleep(0.01)
465+
466+
if self._handler_tasks:
467+
await asyncio.wait(list(self._handler_tasks.values()))
461468

462469
def _start_processing_task(self):
463470
"""Start the background event processing task."""
@@ -488,18 +495,27 @@ async def _process_events_loop(self):
488495
elif cancelled_exc:
489496
raise cancelled_exc
490497
else:
498+
cleanup_ids = set(task_id for task_id, task in self._handler_tasks.items() if task.done())
499+
for task_id in cleanup_ids:
500+
self._handler_tasks.pop(task_id)
491501
await asyncio.sleep(0.0001)
492502

503+
async def _run_handler(self, handler, event):
504+
try:
505+
return await handler(event)
506+
except Exception as exc:
507+
self._queue.appendleft(ExceptionEvent(exc, handler)) # type: ignore[arg-type]
508+
module_name = getattr(handler, '__module__', 'unknown')
509+
logger.exception(f"Error calling handler {handler.__name__} from {module_name} for event {event.type}")
510+
493511
async def _process_single_event(self, event):
494512
"""Process a single event."""
495513
for handler in self._handlers.get(event.type, []):
496-
try:
497-
module_name = getattr(handler, '__module__', 'unknown')
498-
if event.type not in self._silent_events:
499-
logger.info(f"Called handler {handler.__name__} from {module_name} for event {event.type}")
500-
await handler(event)
501-
except Exception as exc:
502-
self._queue.appendleft(ExceptionEvent(exc, handler)) # type: ignore[arg-type]
503-
module_name = getattr(handler, '__module__', 'unknown')
504-
logger.exception(f"Error calling handler {handler.__name__} from {module_name} for event {event.type}")
514+
module_name = getattr(handler, '__module__', 'unknown')
515+
if event.type not in self._silent_events:
516+
logger.info(f"Called handler {handler.__name__} from {module_name} for event {event.type}")
517+
518+
loop = asyncio.get_running_loop()
519+
handler_task = loop.create_task(self._run_handler(handler, event))
520+
self._handler_tasks[uuid.uuid4()] = handler_task
505521

examples/01_simple_agent_example/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ dependencies = [
2525
"vision-agents-plugins-anthropic" = {path = "../../plugins/anthropic", editable=true}
2626
"vision-agents-plugins-getstream" = {path = "../../plugins/getstream", editable=true}
2727
"vision-agents-plugins-openai" = {path = "../../plugins/openai", editable=true}
28+
"vision-agents-plugins-smart-turn" = {path = "../../plugins/smart_turn", editable=true}
2829

2930
"vision-agents" = {path = "../../agents-core", editable=true}

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
102102
# SFU might send TrackPublishedEvent before WebRTC processes track_added
103103
track_id = None
104104
timeout = 10.0
105-
poll_interval = 0.01 # 50ms
105+
poll_interval = 0.01 # 10ms
106106
elapsed = 0.0
107107

108108
while elapsed < timeout:
@@ -125,7 +125,7 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
125125
if track_id:
126126
# Store with correct type from SFU
127127
self._track_map[track_key] = {"track_id": track_id, "published": True}
128-
self.logger.info(f"Track published: {track_type_name} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)")
128+
self.logger.info(f"Trackmap published: {track_type_name} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)")
129129

130130
# NOW spawn TrackAddedEvent with correct type
131131
self.events.send(events.TrackAddedEvent(
@@ -139,6 +139,9 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
139139
raise TimeoutError(
140140
f"Timeout waiting for pending track: {track_type_name} ({expected_kind}) from user {user_id}, "
141141
f"session {session_id}. Waited {timeout}s but WebRTC track_added with matching kind was never received."
142+
f"Pending tracks: {self._pending_tracks}\n"
143+
f"Key: {track_key}\n"
144+
f"Track map: {self._track_map}\n"
142145
)
143146

144147
async def _on_track_removed(self, event: sfu_events.ParticipantLeftEvent | sfu_events.TrackUnpublishedEvent):
@@ -230,16 +233,6 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
230233

231234
self._connection = connection
232235

233-
original_on_subscriber_offer = self._connection._on_subscriber_offer # type: ignore[attr-defined]
234-
235-
async def _safe_on_subscriber_offer(event):
236-
if self._connection.subscriber_pc is None: # type: ignore[attr-defined]
237-
self.logger.debug("Ignoring subscriber offer after subscriber_pc teardown")
238-
return
239-
await original_on_subscriber_offer(event)
240-
241-
self._connection._on_subscriber_offer = _safe_on_subscriber_offer # type: ignore[attr-defined]
242-
243236
@self._connection.on("audio")
244237
async def on_audio_received(pcm: PcmData, participant: Participant):
245238
self.events.send(events.AudioReceivedEvent(

0 commit comments

Comments
 (0)