Skip to content

Commit e63fe0c

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Fix pagination of list_sessions in VertexAiSessionService
Resolves #2860 PiperOrigin-RevId: 804511401
1 parent bc6b546 commit e63fe0c

File tree

2 files changed

+88
-32
lines changed

2 files changed

+88
-32
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -275,36 +275,47 @@ async def list_sessions(
275275
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
276276
api_client = self._get_api_client()
277277

278-
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
279-
if user_id:
280-
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
281-
path = path + f'?filter=user_id={parsed_user_id}'
278+
base_path = f'reasoningEngines/{reasoning_engine_id}/sessions'
279+
sessions = []
280+
page_token = None
281+
while True:
282+
path = base_path
283+
query_params = {}
284+
if user_id:
285+
query_params['filter'] = f'user_id="{user_id}"'
286+
if page_token:
287+
query_params['pageToken'] = page_token
288+
289+
if query_params:
290+
path = f'{path}?{urllib.parse.urlencode(query_params)}'
291+
292+
list_sessions_api_response = await api_client.async_request(
293+
http_method='GET',
294+
path=path,
295+
request_dict={},
296+
)
297+
converted_api_response = _convert_api_response(list_sessions_api_response)
282298

283-
list_sessions_api_response = await api_client.async_request(
284-
http_method='GET',
285-
path=path,
286-
request_dict={},
287-
)
288-
list_sessions_api_response = _convert_api_response(
289-
list_sessions_api_response
290-
)
299+
# Handles empty response case
300+
if not converted_api_response or converted_api_response.get(
301+
'httpHeaders', None
302+
):
303+
break
291304

292-
# Handles empty response case
293-
if not list_sessions_api_response or list_sessions_api_response.get(
294-
'httpHeaders', None
295-
):
296-
return ListSessionsResponse()
305+
for api_session in converted_api_response.get('sessions', []):
306+
session = Session(
307+
app_name=app_name,
308+
user_id=user_id,
309+
id=api_session['name'].split('/')[-1],
310+
state=api_session.get('sessionState', {}),
311+
last_update_time=isoparse(api_session['updateTime']).timestamp(),
312+
)
313+
sessions.append(session)
314+
315+
page_token = converted_api_response.get('nextPageToken')
316+
if not page_token:
317+
break
297318

298-
sessions = []
299-
for api_session in list_sessions_api_response['sessions']:
300-
session = Session(
301-
app_name=app_name,
302-
user_id=user_id,
303-
id=api_session['name'].split('/')[-1],
304-
state=api_session.get('sessionState', {}),
305-
last_update_time=isoparse(api_session['updateTime']).timestamp(),
306-
)
307-
sessions.append(session)
308319
return ListSessionsResponse(sessions=sessions)
309320

310321
async def delete_session(

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional
2020
from typing import Tuple
2121
from unittest import mock
22+
from urllib import parse
2223

2324
from dateutil.parser import isoparse
2425
from google.adk.events.event import Event
@@ -107,6 +108,22 @@
107108
'timestamp': '2024-12-12T12:12:12.123456Z',
108109
},
109110
]
111+
MOCK_SESSION_JSON_PAGE1 = {
112+
'name': (
113+
'projects/test-project/locations/test-location/'
114+
'reasoningEngines/123/sessions/page1'
115+
),
116+
'updateTime': '2024-12-15T12:12:12.123456Z',
117+
'userId': 'user_with_pages',
118+
}
119+
MOCK_SESSION_JSON_PAGE2 = {
120+
'name': (
121+
'projects/test-project/locations/test-location/'
122+
'reasoningEngines/123/sessions/page2'
123+
),
124+
'updateTime': '2024-12-16T12:12:12.123456Z',
125+
'userId': 'user_with_pages',
126+
}
110127

111128
MOCK_SESSION = Session(
112129
app_name='123',
@@ -157,9 +174,7 @@
157174

158175

159176
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
160-
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
161-
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
162-
)
177+
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$'
163178
EVENTS_REGEX = (
164179
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
165180
)
@@ -188,12 +203,28 @@ async def async_request(
188203
else:
189204
raise ValueError(f'Session not found: {session_id}')
190205
elif re.match(SESSIONS_REGEX, path):
191-
match = re.match(SESSIONS_REGEX, path)
206+
parsed_url = parse.urlparse(path)
207+
query_params = parse.parse_qs(parsed_url.query)
208+
filter_val = query_params.get('filter', [''])[0]
209+
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
210+
if not user_id_match:
211+
raise ValueError(f'Could not find user_id in filter: {filter_val}')
212+
user_id = user_id_match.group(1)
213+
214+
if user_id == 'user_with_pages':
215+
page_token = query_params.get('pageToken', [None])[0]
216+
if page_token == 'my_token':
217+
return {'sessions': [MOCK_SESSION_JSON_PAGE2]}
218+
else:
219+
return {
220+
'sessions': [MOCK_SESSION_JSON_PAGE1],
221+
'nextPageToken': 'my_token',
222+
}
192223
return {
193224
'sessions': [
194225
session
195226
for session in self.session_dict.values()
196-
if session['userId'] == match.group(2)
227+
if session['userId'] == user_id
197228
],
198229
}
199230
elif re.match(EVENTS_REGEX, path):
@@ -271,6 +302,8 @@ def mock_get_api_client():
271302
'1': MOCK_SESSION_JSON_1,
272303
'2': MOCK_SESSION_JSON_2,
273304
'3': MOCK_SESSION_JSON_3,
305+
'page1': MOCK_SESSION_JSON_PAGE1,
306+
'page2': MOCK_SESSION_JSON_PAGE2,
274307
}
275308
api_client.event_dict = {
276309
'1': (MOCK_EVENT_JSON, None),
@@ -358,6 +391,18 @@ async def test_list_sessions():
358391
assert sessions.sessions[1].id == '2'
359392

360393

394+
@pytest.mark.asyncio
395+
@pytest.mark.usefixtures('mock_get_api_client')
396+
async def test_list_sessions_with_pagination():
397+
session_service = mock_vertex_ai_session_service()
398+
sessions = await session_service.list_sessions(
399+
app_name='123', user_id='user_with_pages'
400+
)
401+
assert len(sessions.sessions) == 2
402+
assert sessions.sessions[0].id == 'page1'
403+
assert sessions.sessions[1].id == 'page2'
404+
405+
361406
@pytest.mark.asyncio
362407
@pytest.mark.usefixtures('mock_get_api_client')
363408
async def test_create_session():

0 commit comments

Comments
 (0)