|
19 | 19 | from typing import Optional |
20 | 20 | from typing import Tuple |
21 | 21 | from unittest import mock |
| 22 | +from urllib import parse |
22 | 23 |
|
23 | 24 | from dateutil.parser import isoparse |
24 | 25 | from google.adk.events.event import Event |
|
107 | 108 | 'timestamp': '2024-12-12T12:12:12.123456Z', |
108 | 109 | }, |
109 | 110 | ] |
| 111 | +MOCK_SESSION_JSON_PAGE1 = { |
| 112 | + 'name': ( |
| 113 | + 'projects/test-project/locations/test-location/' |
| 114 | + 'reasoningEngines/123/sessions/page1' |
| 115 | + ), |
| 116 | + 'updateTime': '2024-12-15T12:12:12.123456Z', |
| 117 | + 'userId': 'user_with_pages', |
| 118 | +} |
| 119 | +MOCK_SESSION_JSON_PAGE2 = { |
| 120 | + 'name': ( |
| 121 | + 'projects/test-project/locations/test-location/' |
| 122 | + 'reasoningEngines/123/sessions/page2' |
| 123 | + ), |
| 124 | + 'updateTime': '2024-12-16T12:12:12.123456Z', |
| 125 | + 'userId': 'user_with_pages', |
| 126 | +} |
110 | 127 |
|
111 | 128 | MOCK_SESSION = Session( |
112 | 129 | app_name='123', |
|
157 | 174 |
|
158 | 175 |
|
159 | 176 | SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' |
160 | | -SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string |
161 | | - r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' |
162 | | -) |
| 177 | +SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?.*$' |
163 | 178 | EVENTS_REGEX = ( |
164 | 179 | r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?' |
165 | 180 | ) |
@@ -188,12 +203,28 @@ async def async_request( |
188 | 203 | else: |
189 | 204 | raise ValueError(f'Session not found: {session_id}') |
190 | 205 | elif re.match(SESSIONS_REGEX, path): |
191 | | - match = re.match(SESSIONS_REGEX, path) |
| 206 | + parsed_url = parse.urlparse(path) |
| 207 | + query_params = parse.parse_qs(parsed_url.query) |
| 208 | + filter_val = query_params.get('filter', [''])[0] |
| 209 | + user_id_match = re.search(r'user_id="([^"]+)"', filter_val) |
| 210 | + if not user_id_match: |
| 211 | + raise ValueError(f'Could not find user_id in filter: {filter_val}') |
| 212 | + user_id = user_id_match.group(1) |
| 213 | + |
| 214 | + if user_id == 'user_with_pages': |
| 215 | + page_token = query_params.get('pageToken', [None])[0] |
| 216 | + if page_token == 'my_token': |
| 217 | + return {'sessions': [MOCK_SESSION_JSON_PAGE2]} |
| 218 | + else: |
| 219 | + return { |
| 220 | + 'sessions': [MOCK_SESSION_JSON_PAGE1], |
| 221 | + 'nextPageToken': 'my_token', |
| 222 | + } |
192 | 223 | return { |
193 | 224 | 'sessions': [ |
194 | 225 | session |
195 | 226 | for session in self.session_dict.values() |
196 | | - if session['userId'] == match.group(2) |
| 227 | + if session['userId'] == user_id |
197 | 228 | ], |
198 | 229 | } |
199 | 230 | elif re.match(EVENTS_REGEX, path): |
@@ -271,6 +302,8 @@ def mock_get_api_client(): |
271 | 302 | '1': MOCK_SESSION_JSON_1, |
272 | 303 | '2': MOCK_SESSION_JSON_2, |
273 | 304 | '3': MOCK_SESSION_JSON_3, |
| 305 | + 'page1': MOCK_SESSION_JSON_PAGE1, |
| 306 | + 'page2': MOCK_SESSION_JSON_PAGE2, |
274 | 307 | } |
275 | 308 | api_client.event_dict = { |
276 | 309 | '1': (MOCK_EVENT_JSON, None), |
@@ -358,6 +391,18 @@ async def test_list_sessions(): |
358 | 391 | assert sessions.sessions[1].id == '2' |
359 | 392 |
|
360 | 393 |
|
| 394 | +@pytest.mark.asyncio |
| 395 | +@pytest.mark.usefixtures('mock_get_api_client') |
| 396 | +async def test_list_sessions_with_pagination(): |
| 397 | + session_service = mock_vertex_ai_session_service() |
| 398 | + sessions = await session_service.list_sessions( |
| 399 | + app_name='123', user_id='user_with_pages' |
| 400 | + ) |
| 401 | + assert len(sessions.sessions) == 2 |
| 402 | + assert sessions.sessions[0].id == 'page1' |
| 403 | + assert sessions.sessions[1].id == 'page2' |
| 404 | + |
| 405 | + |
361 | 406 | @pytest.mark.asyncio |
362 | 407 | @pytest.mark.usefixtures('mock_get_api_client') |
363 | 408 | async def test_create_session(): |
|
0 commit comments