Commit 0453985

Iteration №2: use separate ABCs instead of flags

1 parent 4c3b8b7 commit 0453985

11 files changed: +133 −73 lines changed

DEVELOPMENT.md

Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ To see how the agent work open up agents.py
 
 **Video**
 
-* The agent receives the video track, and calls agent.llm._watch_video_track
+* The agent receives the video track, and calls agent.llm.watch_video_track
 * The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
 * The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent
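For orientation, the hand-off described above reduces to a capability check followed by the new public hook. A minimal sketch; `hand_off_video`, `edge_track`, and `forwarder` are hypothetical stand-ins for objects the Agent already holds:

```python
from vision_agents.core.llm.llm import VideoLLM


async def hand_off_video(agent, edge_track, forwarder) -> None:
    # Only video-capable LLMs expose watch_video_track; plain LLMs are skipped.
    if isinstance(agent.llm, VideoLLM):
        # The LLM consumes frames via the shared forwarder and streams them
        # to its websocket or WebRTC connection.
        await agent.llm.watch_video_track(edge_track, shared_forwarder=forwarder)
```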

agents-core/vision_agents/core/agents/agents.py

Lines changed: 39 additions & 29 deletions

@@ -5,7 +5,7 @@
 import time
 import uuid
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeGuard
 from uuid import uuid4
 
 import getstream.models
@@ -30,7 +30,7 @@
     RealtimeUserSpeechTranscriptionEvent,
     RealtimeAgentSpeechTranscriptionEvent,
 )
-from ..llm.llm import LLM
+from ..llm.llm import AudioLLM, LLM, VideoLLM
 from ..llm.realtime import Realtime
 from ..mcp import MCPBaseServer, MCPManager
 from ..processors.base_processor import Processor, ProcessorType, filter_processors
@@ -109,6 +109,18 @@ def default_agent_options():
     return AgentOptions(model_dir=_DEFAULT_MODEL_DIR)
 
 
+def _is_audio_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[AudioLLM]:
+    return isinstance(llm, AudioLLM)
+
+
+def _is_video_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[VideoLLM]:
+    return isinstance(llm, VideoLLM)
+
+
+def _is_realtime_llm(llm: LLM | AudioLLM | VideoLLM | Realtime) -> TypeGuard[Realtime]:
+    return isinstance(llm, Realtime)
+
+
 class Agent:
     """
     Agent class makes it easy to build your own video AI.
@@ -139,7 +151,7 @@ def __init__(
         # edge network for video & audio
         edge: "StreamEdge",
         # llm, optionally with sts/realtime capabilities
-        llm: LLM | Realtime,
+        llm: LLM | AudioLLM | VideoLLM,
         # the agent's user info
         agent_user: User,
         # instructions
@@ -424,7 +436,7 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
 
         @self.events.subscribe
         async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
-            if self.llm.handles_audio:
+            if _is_audio_llm(self.llm):
                 # There is no need to send the response to the LLM if it handles audio itself.
                 return
 
@@ -493,7 +505,7 @@ async def join(self, call: Call) -> "AgentSessionContextManager":
 
         # Ensure Realtime providers are ready before proceeding (they manage their own connection)
         self.logger.info(f"🤖 Agent joining call: {call.id}")
-        if isinstance(self.llm, Realtime):
+        if _is_realtime_llm(self.llm):
            await self.llm.connect()
 
        with self.span("edge.join"):
@@ -812,12 +824,12 @@ async def on_video_track_added(event: TrackAddedEvent):
                    f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
                )
 
-                if self.llm.handles_video:
+                if _is_video_llm(self.llm):
                    # Get the existing forwarder and switch to this track
                    _, _, forwarder = self._active_video_tracks[track_id]
                    track = self.edge.add_track_subscriber(track_id)
                    if track and forwarder:
-                        await self.llm._watch_video_track(
+                        await self.llm.watch_video_track(
                            track, shared_forwarder=forwarder
                        )
                        self._current_video_track_id = track_id
@@ -846,7 +858,7 @@ async def on_video_track_removed(event: TrackRemovedEvent):
            self._active_video_tracks.pop(track_id, None)
 
            # If this was the active track, switch to any other available track
-            if self.llm.handles_video and track_id == self._current_video_track_id:
+            if _is_video_llm(self.llm) and track_id == self._current_video_track_id:
                self.logger.info(
                    "🎥 Active video track removed, switching to next available"
                )
@@ -872,7 +884,7 @@ async def _reply_to_audio(
        )
 
        # when in Realtime mode call the Realtime directly (non-blocking)
-        if self.llm.handles_audio:
+        if _is_audio_llm(self.llm):
            # TODO: this behaviour should be easy to change in the agent class
            asyncio.create_task(
                self.llm.simple_audio_response(pcm_data, participant)
@@ -908,9 +920,9 @@ async def _switch_to_next_available_track(self) -> None:
 
            # Get the track and forwarder
            track = self.edge.add_track_subscriber(track_id)
-            if track and forwarder and isinstance(self.llm, Realtime):
+            if track and forwarder and _is_video_llm(self.llm):
                # Send to Realtime provider
-                await self.llm._watch_video_track(track, shared_forwarder=forwarder)
+                await self.llm.watch_video_track(track, shared_forwarder=forwarder)
                self._current_video_track_id = track_id
                return
            else:
@@ -973,7 +985,7 @@ async def recv(self):
            # If Realtime provider supports video, switch to this new track
            track_type_name = TrackType.Name(track_type)
 
-            if self.llm.handles_video:
+            if _is_video_llm(self.llm):
                if self._video_track:
                    # We have a video publisher (e.g., YOLO processor)
                    # Create a separate forwarder for the PROCESSED video track
@@ -989,22 +1001,20 @@ async def recv(self):
                    await processed_forwarder.start()
                    self._video_forwarders.append(processed_forwarder)
 
-                    if isinstance(self.llm, Realtime):
-                        # Send PROCESSED frames with the processed forwarder
-                        await self.llm._watch_video_track(
-                            self._video_track, shared_forwarder=processed_forwarder
-                        )
-                        self._current_video_track_id = track_id
+                    # Send PROCESSED frames with the processed forwarder
+                    await self.llm.watch_video_track(
+                        self._video_track, shared_forwarder=processed_forwarder
+                    )
+                    self._current_video_track_id = track_id
                else:
                    # No video publisher, send raw frames - switch to this new track
                    self.logger.info(
                        f"🎥 Switching to {track_type_name} track: {track_id}"
                    )
-                    if isinstance(self.llm, Realtime):
-                        await self.llm._watch_video_track(
-                            track, shared_forwarder=raw_forwarder
-                        )
-                        self._current_video_track_id = track_id
+                    await self.llm.watch_video_track(
+                        track, shared_forwarder=raw_forwarder
+                    )
+                    self._current_video_track_id = track_id
 
        has_image_processors = len(self.image_processors) > 0
 
@@ -1096,7 +1106,7 @@ async def recv(self):
    async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
        """Handle turn detection events."""
        # Skip the turn event handling if the model doesn't require TTS or SST audio itself.
-        if not (self.llm.needs_tts and self.llm.needs_stt):
+        if _is_audio_llm(self.llm):
            return
 
        if isinstance(event, TurnStartedEvent):
@@ -1167,7 +1177,7 @@ def publish_audio(self) -> bool:
        Returns:
            True if TTS is configured, when in Realtime mode, or if there are audio publishers.
        """
-        if self.tts is not None or self.llm.handles_audio:
+        if self.tts is not None or _is_audio_llm(self.llm):
            return True
        # Also publish audio if there are audio publishers (e.g., HeyGen avatar)
        if self.audio_publishers:
@@ -1204,7 +1214,7 @@ def _needs_audio_or_video_input(self) -> bool:
        # Video input needed for:
        # - Video processors (for frame analysis)
        # - Realtime mode with video (multimodal LLMs)
-        needs_video = len(self.video_processors) > 0 or self.llm.handles_video
+        needs_video = len(self.video_processors) > 0 or _is_video_llm(self.llm)
 
        return needs_audio or needs_video
 
@@ -1255,7 +1265,7 @@ def image_processors(self) -> List[Any]:
 
    def _validate_configuration(self):
        """Validate the agent configuration."""
-        if self.llm.handles_audio:
+        if _is_audio_llm(self.llm):
            # Realtime mode - should not have separate STT/TTS
            if self.stt or self.tts:
                self.logger.warning(
@@ -1292,8 +1302,8 @@ def _prepare_rtc(self):
 
        # Set up audio track if TTS is available
        if self.publish_audio:
-            if self.llm.handles_audio:
-                self._audio_track = self.llm.output_track
+            if _is_audio_llm(self.llm):
+                self._audio_track = self.llm.output_audio_track
                self.logger.info("🎵 Using Realtime provider output track for audio")
            elif self.audio_publishers:
                # Get the first audio publisher to create the track
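Why `TypeGuard` helpers rather than inline `isinstance` checks: a `TypeGuard`-annotated predicate gives the capability check a single readable name while still letting static checkers narrow the `LLM | AudioLLM | VideoLLM` union at each call site. A self-contained sketch of the pattern, with toy classes rather than the real ones:

```python
from typing import TypeGuard


class Model: ...


class AudioModel(Model):
    async def respond_to_audio(self) -> None: ...


def is_audio_model(m: Model) -> TypeGuard[AudioModel]:
    return isinstance(m, AudioModel)


async def handle(m: Model) -> None:
    if is_audio_model(m):
        # Within this branch, checkers treat `m` as AudioModel, so the
        # audio-only method type-checks without a cast.
        await m.respond_to_audio()
```
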
agents-core/vision_agents/core/llm/__init__.py

Lines changed: 10 additions & 2 deletions

@@ -1,5 +1,13 @@
-from .llm import LLM
+from .llm import LLM, AudioLLM, VideoLLM, OmniLLM
 from .realtime import Realtime
 from .function_registry import FunctionRegistry, function_registry
 
-__all__ = ["LLM", "Realtime", "FunctionRegistry", "function_registry"]
+__all__ = [
+    "LLM",
+    "AudioLLM",
+    "VideoLLM",
+    "OmniLLM",
+    "Realtime",
+    "FunctionRegistry",
+    "function_registry",
+]
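With the re-exports above, downstream code can depend on the capability ABCs from the package root. A hypothetical consumer (`describe_capabilities` is illustrative, not part of this commit):

```python
from vision_agents.core.llm import LLM, AudioLLM, VideoLLM


def describe_capabilities(llm: LLM) -> str:
    # The ABCs replace the old needs_stt/needs_tts/handles_* flag set.
    audio = "speech-to-speech" if isinstance(llm, AudioLLM) else "needs STT/TTS"
    video = "video-aware" if isinstance(llm, VideoLLM) else "audio/text only"
    return f"{audio}, {video}"
```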

agents-core/vision_agents/core/llm/llm.py

Lines changed: 42 additions & 7 deletions

@@ -15,6 +15,7 @@
     Generic,
 )
 
+import aiortc
 from vision_agents.core.llm import events
 from vision_agents.core.llm.events import ToolStartEvent, ToolEndEvent
 
@@ -23,11 +24,13 @@
     from vision_agents.core.agents.conversation import Conversation
 
 from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant
+from getstream.video.rtc import AudioStreamTrack, PcmData
 from vision_agents.core.processors import Processor
 from vision_agents.core.utils.utils import parse_instructions
 from vision_agents.core.events.manager import EventManager
 from .function_registry import FunctionRegistry
 from .llm_types import ToolSchema, NormalizedToolCallItem
+from ..utils.video_forwarder import VideoForwarder
 
 T = TypeVar("T")
 
@@ -44,13 +47,6 @@ def __init__(self, original: T, text: str, exception: Optional[Exception] = None
 
 
 class LLM(abc.ABC):
-    # Instruct the Agent that this model requires STT and TTS services, and it doesn't handle audio and video
-    # on its own.
-    needs_stt: bool = True
-    needs_tts: bool = True
-    handles_audio: bool = False
-    handles_video: bool = False
-
     before_response_listener: BeforeCb
     after_response_listener: AfterCb
     agent: Optional["Agent"]
@@ -407,3 +403,42 @@ def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str:
         """
         s = value if isinstance(value, str) else json.dumps(value)
         return (s[:max_chars] + "…") if len(s) > max_chars else s
+
+
+class AudioLLM(LLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of processing speech-to-speech audio.
+    These models do not require TTS and STT services to run.
+    """
+
+    @abc.abstractmethod
+    async def simple_audio_response(
+        self, pcm: PcmData, participant: Optional[Participant] = None
+    ): ...
+
+    @property
+    @abc.abstractmethod
+    def output_audio_track(self) -> AudioStreamTrack: ...
+
+
+class VideoLLM(LLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of processing video.
+
+    These models will receive the video track from the `Agent` to analyze it.
+    """
+
+    @abc.abstractmethod
+    async def watch_video_track(
+        self,
+        track: aiortc.mediastreams.MediaStreamTrack,
+        shared_forwarder: Optional[VideoForwarder] = None,
+    ) -> None: ...
+
+
+class OmniLLM(AudioLLM, VideoLLM, metaclass=abc.ABCMeta):
+    """
+    A base class for LLMs capable of both video and speech-to-speech audio processing.
+    """
+
+    ...
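To make the contract concrete, a hypothetical minimal provider against the new `AudioLLM` ABC might look like the following. Any abstract members inherited from `LLM` itself are elided, the response body is a placeholder, and the track parameters mirror the defaults being removed from `Realtime` below:

```python
from typing import Optional

from getstream.video.rtc import AudioStreamTrack, PcmData
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant

from vision_agents.core.llm.llm import AudioLLM


class MyAudioProvider(AudioLLM):
    def __init__(self) -> None:
        super().__init__()
        # 48 kHz stereo s16, as the old Realtime base class defaulted to.
        self._track = AudioStreamTrack(sample_rate=48000, channels=2, format="s16")

    async def simple_audio_response(
        self, pcm: PcmData, participant: Optional[Participant] = None
    ) -> None:
        # A real provider would stream `pcm` to its speech-to-speech model
        # and write the generated reply onto `self._track`.
        ...

    @property
    def output_audio_track(self) -> AudioStreamTrack:
        return self._track
```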

agents-core/vision_agents/core/llm/realtime.py

Lines changed: 2 additions & 20 deletions

@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 from typing import (
-    Any,
     Optional,
 )
 
-from getstream.video.rtc.audio_track import AudioStreamTrack
 from getstream.video.rtc.track_util import PcmData
 from vision_agents.core.edge.types import Participant
 
@@ -14,14 +12,13 @@
 import logging
 import uuid
 
-
-from . import events, LLM
+from . import events, OmniLLM
 
 
 logger = logging.getLogger(__name__)
 
 
-class Realtime(LLM, abc.ABC):
+class Realtime(OmniLLM):
     """
     Realtime is an abstract base class for LLMs that can receive audio and video
@@ -42,13 +39,6 @@ class Realtime(LLM, abc.ABC):
     fps: int = 1
     session_id: str  # UUID to identify this session
 
-    # Instruct the Agent that this model can handle audio and video
-    # without additional STT and TTS services.
-    handles_audio: bool = True
-    handles_video: bool = True
-    needs_stt = False
-    needs_tts = False
-
     def __init__(
         self,
         fps: int = 1,  # the number of video frames per second to send (for implementations that support setting fps)
@@ -59,10 +49,6 @@ def __init__(
         self.provider_name = "realtime_base"
         self.session_id = str(uuid.uuid4())
         self.fps = fps
-        # The most common style output track (webrtc)
-        self.output_track: AudioStreamTrack = AudioStreamTrack(
-            sample_rate=48000, channels=2, format="s16"
-        )
         # Store current participant for user speech transcription events
         self._current_participant: Optional[Participant] = None
 
@@ -74,10 +60,6 @@ async def simple_audio_response(
         self, pcm: PcmData, participant: Optional[Participant] = None
     ): ...
 
-    async def _watch_video_track(self, track: Any, **kwargs) -> None:
-        """Optionally overridden by providers that support video input."""
-        return None
-
     async def _stop_watching_video_track(self) -> None:
         """Optionally overridden by providers that support video input."""
         return None
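Since `Realtime` now derives its audio and video capabilities from `OmniLLM` rather than class-level flags, the default `output_track` constructed here moves to concrete providers, which satisfy `AudioLLM.output_audio_track` themselves. A sketch of how a provider could recreate the removed default; `MyRealtimeProvider` is hypothetical and its other abstract members (e.g. `watch_video_track`, `simple_audio_response`) are elided:

```python
from getstream.video.rtc import AudioStreamTrack

from vision_agents.core.llm.realtime import Realtime


class MyRealtimeProvider(Realtime):
    def __init__(self, fps: int = 1) -> None:
        super().__init__(fps=fps)
        # Same 48 kHz stereo s16 defaults the base class used to build in __init__.
        self._output_track = AudioStreamTrack(
            sample_rate=48000, channels=2, format="s16"
        )

    @property
    def output_audio_track(self) -> AudioStreamTrack:
        return self._output_track
```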
