Skip to content

Commit 2013be5

Browse files
authored
ensure chat works with default types (#99)
1 parent 89d346b commit 2013be5

File tree

1 file changed

+125
-70
lines changed

1 file changed

+125
-70
lines changed

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 125 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import logging
23
import asyncio
34
import os
@@ -8,12 +9,15 @@
89
import aiortc
910
from getstream import AsyncStream
1011
from getstream.chat.async_client import ChatClient
11-
from getstream.models import ChannelInput
12+
from getstream.models import ChannelInput, ChannelMember
1213
from getstream.video import rtc
1314
from getstream.chat.async_channel import Channel
1415
from getstream.video.async_call import Call
1516
from getstream.video.rtc import ConnectionManager, audio_track
16-
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant, TrackType
17+
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import (
18+
Participant,
19+
TrackType,
20+
)
1721
from getstream.video.rtc.track_util import PcmData
1822
from getstream.video.rtc.tracks import SubscriptionConfig, TrackSubscriptionConfig
1923

@@ -37,11 +41,13 @@ def __init__(self, connection: ConnectionManager):
3741
async def close(self):
3842
await self._connection.leave()
3943

44+
4045
class StreamEdge(EdgeTransport):
4146
"""
4247
StreamEdge uses getstream.io's edge network. To support multiple vendors, this means we expose
4348
4449
"""
50+
4551
client: AsyncStream
4652

4753
def __init__(self, **kwargs):
@@ -71,9 +77,15 @@ def __init__(self, **kwargs):
7177
def _get_webrtc_kind(self, track_type_int: int) -> str:
7278
"""Get the expected WebRTC kind (audio/video) for a SFU track type."""
7379
# Map SFU track types to WebRTC kinds
74-
if track_type_int in (TrackType.TRACK_TYPE_AUDIO, TrackType.TRACK_TYPE_SCREEN_SHARE_AUDIO):
80+
if track_type_int in (
81+
TrackType.TRACK_TYPE_AUDIO,
82+
TrackType.TRACK_TYPE_SCREEN_SHARE_AUDIO,
83+
):
7584
return "audio"
76-
elif track_type_int in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE):
85+
elif track_type_int in (
86+
TrackType.TRACK_TYPE_VIDEO,
87+
TrackType.TRACK_TYPE_SCREEN_SHARE,
88+
):
7789
return "video"
7890
else:
7991
# Default to video for unknown types
@@ -101,7 +113,9 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
101113
# First check if track already exists in map (e.g., from previous unpublish/republish)
102114
if track_key in self._track_map:
103115
self._track_map[track_key]["published"] = True
104-
self.logger.info(f"Track marked as published (already existed): {track_key}")
116+
self.logger.info(
117+
f"Track marked as published (already existed): {track_key}"
118+
)
105119
return
106120

107121
# Wait for pending track to be populated (with 10 second timeout)
@@ -113,34 +127,42 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
113127

114128
while elapsed < timeout:
115129
# Find pending track for this user/session with matching kind
116-
for tid, (pending_user, pending_session, pending_kind) in list(self._pending_tracks.items()):
117-
if (pending_user == user_id and
118-
pending_session == session_id and
119-
pending_kind == expected_kind):
130+
for tid, (pending_user, pending_session, pending_kind) in list(
131+
self._pending_tracks.items()
132+
):
133+
if (
134+
pending_user == user_id
135+
and pending_session == session_id
136+
and pending_kind == expected_kind
137+
):
120138
track_id = tid
121139
del self._pending_tracks[tid]
122140
break
123-
141+
124142
if track_id:
125143
break
126-
144+
127145
# Wait a bit before checking again
128146
await asyncio.sleep(poll_interval)
129147
elapsed += poll_interval
130-
148+
131149
if track_id:
132150
# Store with correct type from SFU
133151
self._track_map[track_key] = {"track_id": track_id, "published": True}
134-
self.logger.info(f"Trackmap published: {track_type_int} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)")
135-
152+
self.logger.info(
153+
f"Trackmap published: {track_type_int} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)"
154+
)
155+
136156
# NOW spawn TrackAddedEvent with correct type
137-
self.events.send(events.TrackAddedEvent(
138-
plugin_name="getstream",
139-
track_id=track_id,
140-
track_type=track_type_int,
141-
user=event.participant,
142-
user_metadata=event.participant
143-
))
157+
self.events.send(
158+
events.TrackAddedEvent(
159+
plugin_name="getstream",
160+
track_id=track_id,
161+
track_type=track_type_int,
162+
user=event.participant,
163+
user_metadata=event.participant,
164+
)
165+
)
144166
else:
145167
raise TimeoutError(
146168
f"Timeout waiting for pending track: {track_type_int} ({expected_kind}) from user {user_id}, "
@@ -149,10 +171,12 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
149171
f"Key: {track_key}\n"
150172
f"Track map: {self._track_map}\n"
151173
)
152-
153-
async def _on_track_removed(self, event: sfu_events.ParticipantLeftEvent | sfu_events.TrackUnpublishedEvent):
174+
175+
async def _on_track_removed(
176+
self, event: sfu_events.ParticipantLeftEvent | sfu_events.TrackUnpublishedEvent
177+
):
154178
"""Handle track unpublished and participant left events."""
155-
if not event.payload: # NOTE: mypy typecheck
179+
if not event.payload: # NOTE: mypy typecheck
156180
return
157181

158182
participant = event.participant
@@ -164,32 +188,36 @@ async def _on_track_removed(self, event: sfu_events.ParticipantLeftEvent | sfu_e
164188
session_id = event.payload.session_id
165189

166190
# Determine which tracks to remove
167-
if hasattr(event.payload, 'type') and event.payload is not None:
191+
if hasattr(event.payload, "type") and event.payload is not None:
168192
# TrackUnpublishedEvent - single track
169193
tracks_to_remove = [event.payload.type]
170194
event_desc = "Track unpublished"
171195
else:
172196
# ParticipantLeftEvent - all published tracks
173-
tracks_to_remove = (event.participant.published_tracks if event.participant else None) or []
197+
tracks_to_remove = (
198+
event.participant.published_tracks if event.participant else None
199+
) or []
174200
event_desc = "Participant left"
175-
201+
176202
track_names = [TrackType.Name(t) for t in tracks_to_remove]
177203
self.logger.info(f"{event_desc}: {user_id}, tracks: {track_names}")
178-
204+
179205
# Mark each track as unpublished and send TrackRemovedEvent
180206
for track_type_int in tracks_to_remove:
181207
track_key = (user_id, session_id, track_type_int)
182208
track_info = self._track_map.get(track_key)
183209

184210
if track_info:
185211
track_id = track_info["track_id"]
186-
self.events.send(events.TrackRemovedEvent(
187-
plugin_name="getstream",
188-
track_id=track_id,
189-
track_type=track_type_int,
190-
user=participant,
191-
user_metadata=participant
192-
))
212+
self.events.send(
213+
events.TrackRemovedEvent(
214+
plugin_name="getstream",
215+
track_id=track_id,
216+
track_type=track_type_int,
217+
user=participant,
218+
user_metadata=participant,
219+
)
220+
)
193221
# Mark as unpublished instead of removing
194222
self._track_map[track_key]["published"] = False
195223
else:
@@ -240,31 +268,42 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
240268
async def on_track(track_id, track_type, user):
241269
# Store track in pending map - wait for SFU to confirm type before spawning TrackAddedEvent
242270
self._pending_tracks[track_id] = (user.user_id, user.session_id, track_type)
243-
self.logger.info(f"Track received from WebRTC (pending SFU confirmation): {track_id}, type: {track_type}, user: {user.user_id}")
271+
self.logger.info(
272+
f"Track received from WebRTC (pending SFU confirmation): {track_id}, type: {track_type}, user: {user.user_id}"
273+
)
244274

245275
self.events.silent(events.AudioReceivedEvent)
276+
246277
@connection.on("audio")
247278
async def on_audio_received(pcm: PcmData, participant: Participant):
248-
self.events.send(events.AudioReceivedEvent(
249-
plugin_name="getstream",
250-
pcm_data=pcm,
251-
participant=participant,
252-
user_metadata=participant
253-
))
254-
255-
await connection.__aenter__() # TODO: weird API? there should be a manual version
279+
self.events.send(
280+
events.AudioReceivedEvent(
281+
plugin_name="getstream",
282+
pcm_data=pcm,
283+
participant=participant,
284+
user_metadata=participant,
285+
)
286+
)
287+
288+
await (
289+
connection.__aenter__()
290+
) # TODO: weird API? there should be a manual version
256291
self._connection = connection
257292

258293
standardize_connection = StreamConnection(connection)
259294
return standardize_connection
260295

261296
def create_audio_track(self, framerate: int = 48000, stereo: bool = True):
262-
return audio_track.AudioStreamTrack(framerate=framerate, stereo=stereo) # default to webrtc framerate
297+
return audio_track.AudioStreamTrack(
298+
framerate=framerate, stereo=stereo
299+
) # default to webrtc framerate
263300

264301
def create_video_track(self):
265302
return aiortc.VideoStreamTrack()
266303

267-
def add_track_subscriber(self, track_id: str) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
304+
def add_track_subscriber(
305+
self, track_id: str
306+
) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
268307
return self._connection.subscriber_pc.add_track_subscriber(track_id)
269308

270309
async def publish_tracks(self, audio_track, video_track):
@@ -301,6 +340,45 @@ async def open_demo(self, call: Call) -> str:
301340

302341
# Create the user in the GetStream system
303342
await client.create_user(name=name, id=human_id)
343+
344+
# Ensure that both agent and user get access the demo by adding the user as member and the agent the channel creator
345+
channel = client.chat.channel(self.channel_type, call.id)
346+
response = await channel.get_or_create(
347+
data=ChannelInput(
348+
created_by_id=self.agent_user_id,
349+
members=[
350+
ChannelMember(
351+
user_id=human_id,
352+
# TODO: get rid of this when codegen for stream-py is fixed, these fields are meaningless
353+
banned=False,
354+
channel_role="",
355+
created_at=datetime.datetime.now(datetime.UTC),
356+
notifications_muted=False,
357+
shadow_banned=False,
358+
updated_at=datetime.datetime.now(datetime.UTC),
359+
custom={},
360+
)
361+
],
362+
)
363+
)
364+
365+
if human_id not in [m.user_id for m in response.data.members]:
366+
await channel.update(
367+
add_members=[
368+
ChannelMember(
369+
user_id=human_id,
370+
# TODO: get rid of this when codegen for stream-py is fixed, these fields are meaningless
371+
banned=False,
372+
channel_role="",
373+
created_at=datetime.datetime.now(datetime.UTC),
374+
notifications_muted=False,
375+
shadow_banned=False,
376+
updated_at=datetime.datetime.now(datetime.UTC),
377+
custom={},
378+
)
379+
]
380+
)
381+
304382
# Create user token for browser access
305383
token = client.create_token(human_id, expiration=3600)
306384

@@ -317,7 +395,7 @@ async def open_demo(self, call: Call) -> str:
317395
"bitrate": 12000000,
318396
"w": 1920,
319397
"h": 1080,
320-
# TODO: FPS..., aim at 60fps
398+
"channel_type": self.channel_type,
321399
}
322400

323401
url = f"{base_url}{call.id}?{urlencode(params)}"
@@ -331,26 +409,3 @@ async def open_demo(self, call: Call) -> str:
331409
print(f"Please manually open this URL: {url}")
332410

333411
return url
334-
335-
def open_pronto(self, api_key: str, token: str, call_id: str):
336-
"""Open browser with the video call URL."""
337-
# Use the same URL pattern as the working workout assistant example
338-
base_url = (
339-
f"{os.getenv('EXAMPLE_BASE_URL', 'https://pronto-staging.getstream.io')}/join/"
340-
)
341-
params = {
342-
"api_key": api_key,
343-
"token": token,
344-
"skip_lobby": "true",
345-
"video_encoder": "vp8",
346-
}
347-
348-
url = f"{base_url}{call_id}?{urlencode(params)}"
349-
self.logger.info(f"🌐 Opening browser: {url}")
350-
351-
try:
352-
webbrowser.open(url)
353-
self.logger.info("✅ Browser opened successfully!")
354-
except Exception as e:
355-
self.logger.error(f"❌ Failed to open browser: {e}")
356-
self.logger.info(f"Please manually open this URL: {url}")

0 commit comments

Comments
 (0)