Skip to content

Commit 9211f4c

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Use async for to loop through event iterator to get all events in vertex_ai_session_service
Fix #3559 Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 832476367
1 parent a754c96 commit 9211f4c

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ async def get_session(
178178
)
179179
session.events += [
180180
_from_api_event(event)
181-
for event in events_iterator
181+
async for event in events_iterator
182182
if event.timestamp.timestamp() <= update_timestamp
183183
]
184184

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import copy
15+
import datetime
1616
import re
1717
import types
1818
from typing import Any
@@ -130,6 +130,36 @@
130130
'user_id': 'user_with_pages',
131131
}
132132

133+
MOCK_SESSION_JSON_5 = {
134+
'name': (
135+
'projects/test-project/locations/test-location/'
136+
'reasoningEngines/123/sessions/5'
137+
),
138+
'update_time': '2024-12-12T12:15:12.123456Z',
139+
'user_id': 'user_with_many_events',
140+
}
141+
142+
143+
def _generate_mock_events_for_session_5(num_events):
144+
events = []
145+
start_time = isoparse('2024-12-12T12:12:12.123456Z')
146+
for i in range(num_events):
147+
event_time = start_time + datetime.timedelta(microseconds=i * 1000)
148+
events.append({
149+
'name': (
150+
'projects/test-project/locations/test-location/'
151+
f'reasoningEngines/123/sessions/5/events/{i}'
152+
),
153+
'invocation_id': f'invocation_{i}',
154+
'author': 'user_with_many_events',
155+
'timestamp': event_time.isoformat().replace('+00:00', 'Z'),
156+
})
157+
return events
158+
159+
160+
MANY_EVENTS_COUNT = 200
161+
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
162+
133163
MOCK_SESSION = Session(
134164
app_name='123',
135165
user_id='user',
@@ -228,6 +258,11 @@ def _convert_to_object(data):
228258
return data
229259

230260

261+
async def to_async_iterator(data):
262+
for item in data:
263+
yield item
264+
265+
231266
class MockAsyncClient:
232267
"""Mocks the API Client."""
233268

@@ -330,7 +365,7 @@ async def _list_events(self, name: str, **kwargs):
330365
for event in events
331366
if isoparse(event['timestamp']) >= after_timestamp
332367
]
333-
return [_convert_to_object(event) for event in events]
368+
return to_async_iterator([_convert_to_object(event) for event in events])
334369

335370
async def _append_event(
336371
self,
@@ -496,6 +531,22 @@ async def test_get_session_with_after_timestamp_filter():
496531
assert session.events[0].id == '456'
497532

498533

534+
@pytest.mark.asyncio
535+
@pytest.mark.usefixtures('mock_get_api_client')
536+
async def test_get_session_with_many_events(mock_api_client_instance):
537+
mock_api_client_instance.session_dict['5'] = MOCK_SESSION_JSON_5
538+
mock_api_client_instance.event_dict['5'] = (
539+
copy.deepcopy(MOCK_EVENTS_JSON_5),
540+
None,
541+
)
542+
session_service = mock_vertex_ai_session_service()
543+
session = await session_service.get_session(
544+
app_name='123', user_id='user_with_many_events', session_id='5'
545+
)
546+
assert session is not None
547+
assert len(session.events) == MANY_EVENTS_COUNT
548+
549+
499550
@pytest.mark.asyncio
500551
@pytest.mark.usefixtures('mock_get_api_client')
501552
async def test_list_sessions():
@@ -524,7 +575,13 @@ async def test_list_sessions_all_users():
524575
session_service = mock_vertex_ai_session_service()
525576
sessions = await session_service.list_sessions(app_name='123', user_id=None)
526577
assert len(sessions.sessions) == 5
527-
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
578+
assert {s.id for s in sessions.sessions} == {
579+
'1',
580+
'2',
581+
'3',
582+
'page1',
583+
'page2',
584+
}
528585

529586

530587
@pytest.mark.asyncio

0 commit comments

Comments
 (0)