| 1 | +import base64 |
| 2 | +import io |
| 3 | +import logging |
| 4 | +import os |
| 5 | +from collections import deque |
| 6 | +from typing import Iterator, Optional |
| 7 | + |
| 8 | +import aiortc |
| 9 | +import av |
| 10 | +from PIL.Image import Resampling |
| 11 | +from openai import AsyncOpenAI |
| 12 | + |
| 13 | +from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant |
| 14 | + |
| 15 | +from vision_agents.core.llm.llm import LLMResponseEvent, VideoLLM |
| 16 | +from vision_agents.core.llm.events import ( |
| 17 | + LLMResponseChunkEvent, |
| 18 | + LLMResponseCompletedEvent, |
| 19 | +) |
| 20 | +from vision_agents.core.utils.video_forwarder import VideoForwarder |
| 21 | +from . import events |
| 22 | + |
| 23 | +from vision_agents.core.processors import Processor |
| 24 | + |
| 25 | + |
| 26 | +logger = logging.getLogger(__name__) |
| 27 | + |
| 28 | + |
| 29 | +PLUGIN_NAME = "baseten_vlm" |
| 30 | + |
| 31 | + |
| 32 | +class BasetenVLM(VideoLLM): |
| 33 | + """ |
| 34 | +    Video LLM backed by a Baseten-hosted vision model, accessed through the OpenAI-compatible chat completions API. |
| 35 | + TODO: Tool calling support? |
| 36 | + |
| 37 | + Examples: |
| 38 | + |
| 39 | + from vision_agents.plugins import baseten |
| 40 | + llm = baseten.VLM(model="qwen3vl") |
| 41 | + |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__( |
| 45 | + self, |
| 46 | + model: str, |
| 47 | + api_key: Optional[str] = None, |
| 48 | + base_url: Optional[str] = None, |
| 49 | + fps: int = 1, |
| 50 | + frame_buffer_seconds: int = 10, |
| 51 | + client: Optional[AsyncOpenAI] = None, |
| 52 | + ): |
| 53 | + """ |
| 54 | + Initialize the BasetenVLM class. |
| 55 | + |
| 56 | + Args: |
| 57 | + model (str): The Baseten-hosted model to use. |
| 58 | +            api_key: optional API key. By default, loaded from the BASETEN_API_KEY environment variable. |
| 59 | +            base_url: optional base URL. By default, loaded from the BASETEN_BASE_URL environment variable. |
| 60 | + fps: the number of video frames per second to handle. |
| 61 | + frame_buffer_seconds: the number of seconds to buffer for the model's input. |
| 62 | + Total buffer size = fps * frame_buffer_seconds. |
| 63 | + client: optional `AsyncOpenAI` client. By default, creates a new client object. |
| 64 | + """ |
| 65 | + super().__init__() |
| 66 | + self.model = model |
| 67 | + self.events.register_events_from_module(events) |
| 68 | + |
| 69 | + api_key = api_key or os.getenv("BASETEN_API_KEY") |
| 70 | + base_url = base_url or os.getenv("BASETEN_BASE_URL") |
| 71 | + if client is not None: |
| 72 | + self._client = client |
| 73 | + elif not api_key: |
| 74 | +            raise ValueError("api_key must be provided or set via BASETEN_API_KEY") |
| 75 | + elif not base_url: |
| 76 | +            raise ValueError("base_url must be provided or set via BASETEN_BASE_URL") |
| 77 | + else: |
| 78 | + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) |
| 79 | + |
| 80 | + self._fps = fps |
| 81 | + self._video_forwarder: Optional[VideoForwarder] = None |
| 82 | + |
| 83 | +        # Buffer the latest frame_buffer_seconds of the video track to forward it |
| 84 | +        # to the model together with the user transcripts |
| 85 | + self._frame_buffer: deque[av.VideoFrame] = deque( |
| 86 | + maxlen=fps * frame_buffer_seconds |
| 87 | + ) |
| 88 | + self._frame_width = 800 |
| 89 | + self._frame_height = 600 |
| 90 | + |
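For reference, a minimal construction sketch based on the constructor above. The model slug is a placeholder and `BasetenVLM` is assumed to be importable directly; either rely on the `BASETEN_API_KEY` / `BASETEN_BASE_URL` environment variables, or inject a preconfigured `AsyncOpenAI` client (in which case the key/URL checks are skipped).

```python
import os

from openai import AsyncOpenAI

# Option 1: rely on BASETEN_API_KEY / BASETEN_BASE_URL being exported.
vlm = BasetenVLM(model="my-deployed-model")  # hypothetical model slug

# Option 2: inject an OpenAI-compatible client built elsewhere;
# fps=2 with frame_buffer_seconds=5 keeps a 10-frame buffer.
client = AsyncOpenAI(
    api_key=os.environ["BASETEN_API_KEY"],
    base_url=os.environ["BASETEN_BASE_URL"],
)
vlm = BasetenVLM(model="my-deployed-model", client=client, fps=2, frame_buffer_seconds=5)
```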
| 91 | + async def simple_response( |
| 92 | + self, |
| 93 | + text: str, |
| 94 | + processors: Optional[list[Processor]] = None, |
| 95 | + participant: Optional[Participant] = None, |
| 96 | + ) -> LLMResponseEvent: |
| 97 | + """ |
| 98 | + simple_response is a standardized way to create an LLM response. |
| 99 | + |
| 100 | +        This method is also called every time a new STT transcript is received. |
| 101 | + |
| 102 | + Args: |
| 103 | + text: The text to respond to. |
| 104 | +            processors: list of processors (which carry state about the video/voice AI). |
| 105 | + participant: the Participant object, optional. |
| 106 | + |
| 107 | + Examples: |
| 108 | + |
| 109 | +            await llm.simple_response("say hi to the user, be nice") |
| 110 | + """ |
| 111 | + |
| 112 | +        # TODO: Clean up `_build_enhanced_instructions` and use that. This should probably be compiled at the agent level. |
| 113 | + |
| 114 | + if self._conversation is None: |
| 115 | + # The agent hasn't joined the call yet. |
| 116 | + logger.warning( |
| 117 | + "Cannot create an LLM response - the conversation has not been initialized yet." |
| 118 | + ) |
| 119 | + return LLMResponseEvent(original=None, text="") |
| 120 | + |
| 121 | + messages = [] |
| 122 | + |
| 123 | + # Add Agent's instructions as system prompt. |
| 124 | + if self.instructions: |
| 125 | + messages.append( |
| 126 | + { |
| 127 | + "role": "system", |
| 128 | + "content": self.instructions, |
| 129 | + } |
| 130 | + ) |
| 131 | + |
| 132 | + # TODO: Do we need to limit how many messages we send? |
| 133 | + # Add all messages from the conversation to the prompt |
| 134 | + for message in self._conversation.messages: |
| 135 | + messages.append( |
| 136 | + { |
| 137 | + "role": message.role, |
| 138 | + "content": message.content, |
| 139 | + } |
| 140 | + ) |
| 141 | + |
| 142 | +        # Attach the latest buffered frames to the request |
| 143 | + frames_data = [] |
| 144 | + for frame_bytes in self._get_frames_bytes(): |
| 145 | + frame_b64 = base64.b64encode(frame_bytes).decode("utf-8") |
| 146 | + frame_msg = { |
| 147 | + "type": "image_url", |
| 148 | + "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}"}, |
| 149 | + } |
| 150 | + frames_data.append(frame_msg) |
| 151 | + |
| 152 | + logger.debug( |
| 153 | +            f'Forwarding {len(frames_data)} frames to the Baseten model "{self.model}"' |
| 154 | + ) |
| 155 | + |
| 156 | + messages.append( |
| 157 | + { |
| 158 | + "role": "user", |
| 159 | + "content": frames_data, |
| 160 | + } |
| 161 | + ) |
| 162 | + |
| 163 | + # TODO: Maybe move it to a method, too much code |
| 164 | + try: |
| 165 | + response = await self._client.chat.completions.create( |
| 166 | + messages=messages, model=self.model, stream=True |
| 167 | + ) |
| 168 | + except Exception as e: |
| 169 | + # Send an error event if the request failed |
| 170 | + logger.exception( |
| 171 | + f'Failed to get a response from the Baseten model "{self.model}"' |
| 172 | + ) |
| 173 | + self.events.send( |
| 174 | + events.LLMErrorEvent( |
| 175 | + plugin_name=PLUGIN_NAME, |
| 176 | + error_message=str(e), |
| 177 | + event_data=e, |
| 178 | + ) |
| 179 | + ) |
| 180 | + return LLMResponseEvent(original=None, text="") |
| 181 | + |
| 182 | + i = 0 |
| 183 | + llm_response_event: Optional[LLMResponseEvent] = LLMResponseEvent( |
| 184 | + original=None, text="" |
| 185 | + ) |
| 186 | + text_chunks: list[str] = [] |
| 187 | + total_text = "" |
| 188 | + async for chunk in response: |
| 189 | + if not chunk.choices: |
| 190 | + continue |
| 191 | + |
| 192 | + choice = chunk.choices[0] |
| 193 | + content = choice.delta.content |
| 194 | + finish_reason = choice.finish_reason |
| 195 | + |
| 196 | + if content: |
| 197 | + text_chunks.append(content) |
| 198 | + # Emit delta events for each response chunk. |
| 199 | + self.events.send( |
| 200 | + LLMResponseChunkEvent( |
| 201 | + plugin_name=PLUGIN_NAME, |
| 202 | + content_index=None, |
| 203 | + item_id=chunk.id, |
| 204 | + output_index=0, |
| 205 | + sequence_number=i, |
| 206 | + delta=content, |
| 207 | + ) |
| 208 | + ) |
| 209 | + |
| 210 | + elif finish_reason: |
| 211 | + # Emit the completion event when the response stream is finished. |
| 212 | + total_text = "".join(text_chunks) |
| 213 | + self.events.send( |
| 214 | + LLMResponseCompletedEvent( |
| 215 | + plugin_name=PLUGIN_NAME, |
| 216 | + original=chunk, |
| 217 | + text=total_text, |
| 218 | + item_id=chunk.id, |
| 219 | + ) |
| 220 | + ) |
| 221 | + |
| 222 | + llm_response_event = LLMResponseEvent(original=chunk, text=total_text) |
| 223 | + i += 1 |
| 224 | + |
| 225 | + return llm_response_event |
| 226 | + |
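For clarity, this is roughly the payload shape `simple_response` builds and streams to the OpenAI-compatible endpoint: the agent instructions as a system message, the prior conversation messages, and a final user turn whose content is a list of base64 JPEG data URIs. The values below are illustrative only.

```python
messages = [
    # Agent instructions, if any.
    {"role": "system", "content": "You are a helpful video assistant."},
    # Prior conversation messages.
    {"role": "user", "content": "What am I holding right now?"},
    # Buffered frames, oldest first, as base64-encoded JPEG data URIs.
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}},
        ],
    },
]
# response = await client.chat.completions.create(messages=messages, model="my-deployed-model", stream=True)
```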
| 227 | + async def watch_video_track( |
| 228 | + self, |
| 229 | + track: aiortc.mediastreams.VideoStreamTrack, # TODO: Check if this works, maybe I need to update typings everywhere |
| 230 | + shared_forwarder: Optional[VideoForwarder] = None, |
| 231 | + ) -> None: |
| 232 | + """ |
| 233 | +        Set up video forwarding and start buffering video frames. |
| 234 | + This method is called by the `Agent`. |
| 235 | + |
| 236 | + Args: |
| 237 | + track: instance of VideoStreamTrack. |
| 238 | + shared_forwarder: a shared VideoForwarder instance if present. Defaults to None. |
| 239 | + |
| 240 | + Returns: None |
| 241 | + """ |
| 242 | + |
| 243 | + if self._video_forwarder is not None and shared_forwarder is None: |
| 244 | + logger.warning("Video forwarder already running, stopping the previous one") |
| 245 | + await self._video_forwarder.stop() |
| 246 | + self._video_forwarder = None |
| 247 | + logger.info("Stopped video forwarding") |
| 248 | + |
| 249 | + logger.info("🎥 BasetenVLM subscribing to VideoForwarder") |
| 250 | + if not shared_forwarder: |
| 251 | +            self._video_forwarder = VideoForwarder( |
| 252 | + track, |
| 253 | + max_buffer=10, |
| 254 | + fps=1.0, # Low FPS for VLM |
| 255 | + name="baseten_vlm_forwarder", |
| 256 | + ) |
| 257 | + await self._video_forwarder.start() |
| 258 | + else: |
| 259 | + self._video_forwarder = shared_forwarder |
| 260 | + |
| 261 | + # Start buffering video frames |
| 262 | + await self._video_forwarder.start_event_consumer(self._frame_buffer.append) |
| 263 | + |
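A rough sketch of wiring in a shared forwarder, based only on the `VideoForwarder` calls used above; `track` is assumed to come from the call's subscription logic, and a shared forwarder passed in is expected to have been started by the caller.

```python
# `track` is assumed to be an aiortc VideoStreamTrack obtained elsewhere.
shared = VideoForwarder(track, max_buffer=10, fps=1.0, name="shared_forwarder")
await shared.start()

# Reuse the same forwarder for this VLM (and any other consumers).
await vlm.watch_video_track(track, shared_forwarder=shared)
```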
| 264 | + def _get_frames_bytes(self) -> Iterator[bytes]: |
| 265 | + """ |
| 266 | +        Iterate over all buffered video frames. |
| 267 | + """ |
| 268 | + for frame in self._frame_buffer: |
| 269 | + yield _frame_to_jpeg_bytes( |
| 270 | + frame=frame, |
| 271 | + target_width=self._frame_width, |
| 272 | + target_height=self._frame_height, |
| 273 | + quality=85, |
| 274 | + ) |
| 275 | + |
| 276 | + |
| 277 | +# TODO: Move it to some core utils |
| 278 | +def _frame_to_jpeg_bytes( |
| 279 | + frame: av.VideoFrame, target_width: int, target_height: int, quality: int = 85 |
| 280 | +) -> bytes: |
| 281 | + """ |
| 282 | + Convert a frame to JPEG bytes with resizing. |
| 283 | + |
| 284 | + Args: |
| 285 | + frame: an instance of `av.VideoFrame` |
| 286 | + target_width: target width in pixels |
| 287 | + target_height: target height in pixels |
| 288 | + quality: JPEG quality. Default is 85. |
| 289 | + |
| 290 | + Returns: frame as JPEG bytes. |
| 291 | + |
| 292 | + """ |
| 293 | + # Convert frame to a PIL image |
| 294 | + img = frame.to_image() |
| 295 | + |
| 296 | + # Calculate scaling to maintain aspect ratio |
| 297 | + src_width, src_height = img.size |
| 298 | + # Calculate scale factor (fit within target dimensions) |
| 299 | + scale = min(target_width / src_width, target_height / src_height) |
| 300 | + new_width = int(src_width * scale) |
| 301 | + new_height = int(src_height * scale) |
| 302 | + |
| 303 | + # Resize with aspect ratio maintained |
| 304 | + resized = img.resize((new_width, new_height), Resampling.LANCZOS) |
| 305 | + |
| 306 | + # Save as JPEG with quality control |
| 307 | + buf = io.BytesIO() |
| 308 | + resized.save(buf, "JPEG", quality=quality, optimize=True) |
| 309 | + return buf.getvalue() |