
Commit 16c53c8

Support for video LLMs hosted on Baseten
1 parent 3e3ce4b commit 16c53c8

File tree

9 files changed: +420 −0 lines changed

.env.example

Lines changed: 4 additions & 0 deletions
```diff
@@ -25,3 +25,7 @@ CARTESIA_API_KEY=your_cartesia_api_key_here
 
 # Anthropic API credentials
 ANTHROPIC_API_KEY=your_anthropic_api_key_here
+
+# Baseten API credentials
+BASETEN_API_KEY=your_baseten_api_key_here
+BASETEN_BASE_URL=your_baseten_base_url_here
```
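These two variables are read once, at construction time, with explicit constructor arguments taking precedence. A minimal sketch of the resolution order, mirroring the `BasetenVLM.__init__` logic later in this commit (`_resolve_credentials` is a hypothetical helper for illustration, not part of the plugin):

```python
import os
from typing import Optional


def _resolve_credentials(
    api_key: Optional[str] = None, base_url: Optional[str] = None
) -> tuple[str, str]:
    # Hypothetical helper: explicit arguments win, then the environment;
    # missing values raise, as in BasetenVLM.__init__.
    api_key = api_key or os.getenv("BASETEN_API_KEY")
    base_url = base_url or os.getenv("BASETEN_BASE_URL")
    if not api_key:
        raise ValueError("api_key must be provided")
    if not base_url:
        raise ValueError("base_url must be provided")
    return api_key, base_url
```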

plugins/baseten/README.md

Lines changed: 26 additions & 0 deletions
# Baseten Plugin for Vision Agents

LLM integrations for models hosted on Baseten, for the Vision Agents framework.

TODO

## Installation

```bash
pip install vision-agents-plugins-baseten
```

## Usage

```python
from vision_agents.plugins import baseten

llm = baseten.VLM(model="qwen3vl")
```

## Requirements
- Python 3.10+
- `openai`
- GetStream SDK

## License
MIT

plugins/baseten/py.typed

Whitespace-only changes.

plugins/baseten/pyproject.toml

Lines changed: 36 additions & 0 deletions
```toml
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "vision-agents-plugins-baseten"
dynamic = ["version"]
description = "Baseten plugin for vision agents"
readme = "README.md"
requires-python = ">=3.10"
license = "MIT"
dependencies = [
    "vision-agents",
    "openai>=2.5.0",
]

[project.urls]
Documentation = "https://visionagents.ai/"
Website = "https://visionagents.ai/"
Source = "https://github.com/GetStream/Vision-Agents"

[tool.hatch.version]
source = "vcs"
raw-options = { root = "..", search_parent_directories = true, fallback_version = "0.0.0" }

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.uv.sources]
vision-agents = { workspace = true }

[dependency-groups]
dev = [
    "pytest>=8.4.1",
    "pytest-asyncio>=1.0.0",
]
```
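Because the version is `dynamic` and sourced from VCS tags via `hatch-vcs` (with a `0.0.0` fallback when no tag metadata is available), the installed distribution is the source of truth for the version string. A quick way to check it, assuming the package is installed:

```python
from importlib.metadata import version

# The version is derived from the git tag at build time, not hardcoded.
print(version("vision-agents-plugins-baseten"))
```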
__init__.py

Lines changed: 4 additions & 0 deletions
```python
from .baseten_vlm import BasetenVLM as VLM


__all__ = ["VLM"]
```
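The `as VLM` alias is what lets callers address the class through the plugin namespace instead of its module, as in the class docstring's example. For illustration (the `qwen3vl` model name is the one used in that docstring):

```python
from vision_agents.plugins import baseten

# baseten.VLM resolves to BasetenVLM via the alias above.
llm = baseten.VLM(model="qwen3vl")
```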
baseten_vlm.py

Lines changed: 309 additions & 0 deletions
```python
import base64
import io
import logging
import os
from collections import deque
from typing import Iterator, Optional

import aiortc
import av
from PIL.Image import Resampling
from openai import AsyncOpenAI

from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant

from vision_agents.core.llm.llm import LLMResponseEvent, VideoLLM
from vision_agents.core.llm.events import (
    LLMResponseChunkEvent,
    LLMResponseCompletedEvent,
)
from vision_agents.core.utils.video_forwarder import VideoForwarder
from . import events

from vision_agents.core.processors import Processor


logger = logging.getLogger(__name__)


PLUGIN_NAME = "baseten_vlm"


class BasetenVLM(VideoLLM):
    """
    TODO: Docs
    TODO: Tool calling support?

    Examples:

        from vision_agents.plugins import baseten
        llm = baseten.VLM(model="qwen3vl")

    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        fps: int = 1,
        frame_buffer_seconds: int = 10,
        client: Optional[AsyncOpenAI] = None,
    ):
        """
        Initialize the BasetenVLM class.

        Args:
            model (str): The Baseten-hosted model to use.
            api_key: optional API key. By default, loads from the BASETEN_API_KEY environment variable.
            base_url: optional base URL. By default, loads from the BASETEN_BASE_URL environment variable.
            fps: the number of video frames per second to handle.
            frame_buffer_seconds: the number of seconds of video to buffer for the model's input.
                Total buffer size = fps * frame_buffer_seconds.
            client: optional `AsyncOpenAI` client. By default, creates a new client object.
        """
        super().__init__()
        self.model = model
        self.events.register_events_from_module(events)

        api_key = api_key or os.getenv("BASETEN_API_KEY")
        base_url = base_url or os.getenv("BASETEN_BASE_URL")
        if client is not None:
            self._client = client
        elif not api_key:
            raise ValueError("api_key must be provided")
        elif not base_url:
            raise ValueError("base_url must be provided")
        else:
            self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

        self._fps = fps
        self._video_forwarder: Optional[VideoForwarder] = None

        # Buffer the latest `frame_buffer_seconds` of the video track to forward
        # to the model together with the user transcripts.
        self._frame_buffer: deque[av.VideoFrame] = deque(
            maxlen=fps * frame_buffer_seconds
        )
        self._frame_width = 800
        self._frame_height = 600

    async def simple_response(
        self,
        text: str,
        processors: Optional[list[Processor]] = None,
        participant: Optional[Participant] = None,
    ) -> LLMResponseEvent:
        """
        simple_response is a standardized way to create an LLM response.

        This method is also called every time a new STT transcript is received.

        Args:
            text: The text to respond to.
            processors: list of processors (which contain state) about the video/voice AI.
            participant: the Participant object, optional.

        Examples:

            llm.simple_response("say hi to the user, be nice")
        """

        # TODO: Clean up `_build_enhanced_instructions` and use that. These should
        # probably be compiled at the agent level.

        if self._conversation is None:
            # The agent hasn't joined the call yet.
            logger.warning(
                "Cannot create an LLM response - the conversation has not been initialized yet."
            )
            return LLMResponseEvent(original=None, text="")

        messages = []

        # Add the agent's instructions as the system prompt.
        if self.instructions:
            messages.append(
                {
                    "role": "system",
                    "content": self.instructions,
                }
            )

        # TODO: Do we need to limit how many messages we send?
        # Add all messages from the conversation to the prompt.
        for message in self._conversation.messages:
            messages.append(
                {
                    "role": message.role,
                    "content": message.content,
                }
            )

        # Attach the latest buffered frames to the request.
        frames_data = []
        for frame_bytes in self._get_frames_bytes():
            frame_b64 = base64.b64encode(frame_bytes).decode("utf-8")
            frame_msg = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}"},
            }
            frames_data.append(frame_msg)

        logger.debug(
            f'Forwarding {len(frames_data)} frames to the Baseten model "{self.model}"'
        )

        messages.append(
            {
                "role": "user",
                "content": frames_data,
            }
        )

        # TODO: Maybe move this to a method, too much code.
        try:
            response = await self._client.chat.completions.create(
                messages=messages, model=self.model, stream=True
            )
        except Exception as e:
            # Send an error event if the request failed.
            logger.exception(
                f'Failed to get a response from the Baseten model "{self.model}"'
            )
            self.events.send(
                events.LLMErrorEvent(
                    plugin_name=PLUGIN_NAME,
                    error_message=str(e),
                    event_data=e,
                )
            )
            return LLMResponseEvent(original=None, text="")

        i = 0
        llm_response_event: Optional[LLMResponseEvent] = LLMResponseEvent(
            original=None, text=""
        )
        text_chunks: list[str] = []
        total_text = ""
        async for chunk in response:
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            content = choice.delta.content
            finish_reason = choice.finish_reason

            if content:
                text_chunks.append(content)
                # Emit delta events for each response chunk.
                self.events.send(
                    LLMResponseChunkEvent(
                        plugin_name=PLUGIN_NAME,
                        content_index=None,
                        item_id=chunk.id,
                        output_index=0,
                        sequence_number=i,
                        delta=content,
                    )
                )

            elif finish_reason:
                # Emit the completion event when the response stream is finished.
                total_text = "".join(text_chunks)
                self.events.send(
                    LLMResponseCompletedEvent(
                        plugin_name=PLUGIN_NAME,
                        original=chunk,
                        text=total_text,
                        item_id=chunk.id,
                    )
                )

            llm_response_event = LLMResponseEvent(original=chunk, text=total_text)
            i += 1

        return llm_response_event

    async def watch_video_track(
        self,
        track: aiortc.mediastreams.VideoStreamTrack,  # TODO: Check if this works, maybe I need to update typings everywhere
        shared_forwarder: Optional[VideoForwarder] = None,
    ) -> None:
        """
        Set up video forwarding and start buffering video frames.
        This method is called by the `Agent`.

        Args:
            track: instance of VideoStreamTrack.
            shared_forwarder: a shared VideoForwarder instance if present. Defaults to None.

        Returns: None
        """

        if self._video_forwarder is not None and shared_forwarder is None:
            logger.warning("Video forwarder already running, stopping the previous one")
            await self._video_forwarder.stop()
            self._video_forwarder = None
            logger.info("Stopped video forwarding")

        logger.info("🎥 BasetenVLM subscribing to VideoForwarder")
        if not shared_forwarder:
            self._video_forwarder = VideoForwarder(
                track,
                max_buffer=10,
                fps=1.0,  # Low FPS for VLM
                name="baseten_vlm_forwarder",
            )
            await self._video_forwarder.start()
        else:
            self._video_forwarder = shared_forwarder

        # Start buffering video frames.
        await self._video_forwarder.start_event_consumer(self._frame_buffer.append)

    def _get_frames_bytes(self) -> Iterator[bytes]:
        """
        Iterate over all buffered video frames.
        """
        for frame in self._frame_buffer:
            yield _frame_to_jpeg_bytes(
                frame=frame,
                target_width=self._frame_width,
                target_height=self._frame_height,
                quality=85,
            )


# TODO: Move this to some core utils.
def _frame_to_jpeg_bytes(
    frame: av.VideoFrame, target_width: int, target_height: int, quality: int = 85
) -> bytes:
    """
    Convert a frame to JPEG bytes with resizing.

    Args:
        frame: an instance of `av.VideoFrame`
        target_width: target width in pixels
        target_height: target height in pixels
        quality: JPEG quality. Default is 85.

    Returns: frame as JPEG bytes.
    """
    # Convert the frame to a PIL image.
    img = frame.to_image()

    # Calculate the scale factor that fits within the target dimensions
    # while maintaining the aspect ratio.
    src_width, src_height = img.size
    scale = min(target_width / src_width, target_height / src_height)
    new_width = int(src_width * scale)
    new_height = int(src_height * scale)

    # Resize with the aspect ratio maintained.
    resized = img.resize((new_width, new_height), Resampling.LANCZOS)

    # Save as JPEG with quality control.
    buf = io.BytesIO()
    resized.save(buf, "JPEG", quality=quality, optimize=True)
    return buf.getvalue()
```
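Since `_frame_to_jpeg_bytes` is self-contained, its fit-within-bounds behavior is easy to verify in isolation. A minimal sketch, assuming `av`, `numpy`, and Pillow are installed and the function above is importable; the 1280×720 synthetic frame is illustrative:

```python
import av
import numpy as np

# A synthetic 1280x720 RGB frame (16:9, wider than the 800x600 target box).
rgb = np.zeros((720, 1280, 3), dtype=np.uint8)
rgb[:, :, 0] = 255  # solid red
frame = av.VideoFrame.from_ndarray(rgb, format="rgb24")

jpeg = _frame_to_jpeg_bytes(frame, target_width=800, target_height=600, quality=85)

# scale = min(800/1280, 600/720) = 0.625, so the output is 800x450:
# the frame is fitted inside the target box, never stretched to fill it.
print(len(jpeg))  # JPEG payload size in bytes
```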
events.py

Lines changed: 12 additions & 0 deletions
```python
from dataclasses import dataclass, field
from vision_agents.core.events import PluginBaseEvent
from typing import Optional, Any


@dataclass
class LLMErrorEvent(PluginBaseEvent):
    """Event emitted when an LLM encounters an error."""

    type: str = field(default="plugin.llm.error", init=False)
    error_message: Optional[str] = None
    event_data: Optional[Any] = None
```
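For illustration, this is how `simple_response` constructs the event on a failed request (assuming `plugin_name` is a field inherited from `PluginBaseEvent`, as its usage in this commit suggests):

```python
err = RuntimeError("connection reset")
event = LLMErrorEvent(
    plugin_name="baseten_vlm",
    error_message=str(err),
    event_data=err,
)
assert event.type == "plugin.llm.error"  # set by the init=False field default
```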
