|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
15 | 14 | import copy |
| 15 | +import datetime |
16 | 16 | import re |
17 | 17 | import types |
18 | 18 | from typing import Any |
|
130 | 130 | 'user_id': 'user_with_pages', |
131 | 131 | } |
132 | 132 |
|
| 133 | +MOCK_SESSION_JSON_5 = { |
| 134 | + 'name': ( |
| 135 | + 'projects/test-project/locations/test-location/' |
| 136 | + 'reasoningEngines/123/sessions/5' |
| 137 | + ), |
| 138 | + 'update_time': '2024-12-12T12:15:12.123456Z', |
| 139 | + 'user_id': 'user_with_many_events', |
| 140 | +} |
| 141 | + |
| 142 | + |
| 143 | +def _generate_mock_events_for_session_5(num_events): |
| 144 | + events = [] |
| 145 | + start_time = isoparse('2024-12-12T12:12:12.123456Z') |
| 146 | + for i in range(num_events): |
| 147 | + event_time = start_time + datetime.timedelta(microseconds=i * 1000) |
| 148 | + events.append({ |
| 149 | + 'name': ( |
| 150 | + 'projects/test-project/locations/test-location/' |
| 151 | + f'reasoningEngines/123/sessions/5/events/{i}' |
| 152 | + ), |
| 153 | + 'invocation_id': f'invocation_{i}', |
| 154 | + 'author': 'user_with_many_events', |
| 155 | + 'timestamp': event_time.isoformat().replace('+00:00', 'Z'), |
| 156 | + }) |
| 157 | + return events |
| 158 | + |
| 159 | + |
| 160 | +MANY_EVENTS_COUNT = 200 |
| 161 | +MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT) |
| 162 | + |
133 | 163 | MOCK_SESSION = Session( |
134 | 164 | app_name='123', |
135 | 165 | user_id='user', |
@@ -228,6 +258,11 @@ def _convert_to_object(data): |
228 | 258 | return data |
229 | 259 |
|
230 | 260 |
|
| 261 | +async def to_async_iterator(data): |
| 262 | + for item in data: |
| 263 | + yield item |
| 264 | + |
| 265 | + |
231 | 266 | class MockAsyncClient: |
232 | 267 | """Mocks the API Client.""" |
233 | 268 |
|
@@ -330,7 +365,7 @@ async def _list_events(self, name: str, **kwargs): |
330 | 365 | for event in events |
331 | 366 | if isoparse(event['timestamp']) >= after_timestamp |
332 | 367 | ] |
333 | | - return [_convert_to_object(event) for event in events] |
| 368 | + return to_async_iterator([_convert_to_object(event) for event in events]) |
334 | 369 |
|
335 | 370 | async def _append_event( |
336 | 371 | self, |
@@ -496,6 +531,22 @@ async def test_get_session_with_after_timestamp_filter(): |
496 | 531 | assert session.events[0].id == '456' |
497 | 532 |
|
498 | 533 |
|
| 534 | +@pytest.mark.asyncio |
| 535 | +@pytest.mark.usefixtures('mock_get_api_client') |
| 536 | +async def test_get_session_with_many_events(mock_api_client_instance): |
| 537 | + mock_api_client_instance.session_dict['5'] = MOCK_SESSION_JSON_5 |
| 538 | + mock_api_client_instance.event_dict['5'] = ( |
| 539 | + copy.deepcopy(MOCK_EVENTS_JSON_5), |
| 540 | + None, |
| 541 | + ) |
| 542 | + session_service = mock_vertex_ai_session_service() |
| 543 | + session = await session_service.get_session( |
| 544 | + app_name='123', user_id='user_with_many_events', session_id='5' |
| 545 | + ) |
| 546 | + assert session is not None |
| 547 | + assert len(session.events) == MANY_EVENTS_COUNT |
| 548 | + |
| 549 | + |
499 | 550 | @pytest.mark.asyncio |
500 | 551 | @pytest.mark.usefixtures('mock_get_api_client') |
501 | 552 | async def test_list_sessions(): |
@@ -524,7 +575,13 @@ async def test_list_sessions_all_users(): |
524 | 575 | session_service = mock_vertex_ai_session_service() |
525 | 576 | sessions = await session_service.list_sessions(app_name='123', user_id=None) |
526 | 577 | assert len(sessions.sessions) == 5 |
527 | | - assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'} |
| 578 | + assert {s.id for s in sessions.sessions} == { |
| 579 | + '1', |
| 580 | + '2', |
| 581 | + '3', |
| 582 | + 'page1', |
| 583 | + 'page2', |
| 584 | + } |
528 | 585 |
|
529 | 586 |
|
530 | 587 | @pytest.mark.asyncio |
|
0 commit comments