Skip to content

Commit 6a5eac0

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Allow passing extra kwargs to create_session of VertexAiSessionService
This can be used to set ttl and other configs. PiperOrigin-RevId: 821782343
1 parent 0b73a69 commit 6a5eac0

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,23 @@ async def create_session(
7373
user_id: str,
7474
state: Optional[dict[str, Any]] = None,
7575
session_id: Optional[str] = None,
76+
**kwargs: Any,
7677
) -> Session:
78+
"""Creates a new session.
79+
80+
Args:
81+
app_name: The name of the application.
82+
user_id: The ID of the user.
83+
state: The initial state of the session.
84+
session_id: The ID of the session.
85+
**kwargs: Additional arguments to pass to the session creation. E.g. set
86+
expire_time='2025-10-01T00:00:00Z' to set the session expiration time.
87+
See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions
88+
for more details.
89+
Returns:
90+
The created session.
91+
"""
92+
7793
if session_id:
7894
raise ValueError(
7995
'User-provided Session id is not supported for'
@@ -84,6 +100,7 @@ async def create_session(
84100
api_client = self._get_api_client()
85101

86102
config = {'session_state': state} if state else {}
103+
config.update(kwargs)
87104

88105
if _is_vertex_express_mode(self._project, self._location):
89106
config['wait_for_completion'] = False

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def __init__(self) -> None:
242242
self.agent_engines.sessions.create.side_effect = self._create_session
243243
self.agent_engines.sessions.events.list.side_effect = self._list_events
244244
self.agent_engines.sessions.events.append.side_effect = self._append_event
245+
self.last_create_session_config: dict[str, Any] = {}
245246

246247
def _get_session(self, name: str):
247248
session_id = name.split('/')[-1]
@@ -275,6 +276,7 @@ def _delete_session(self, name: str):
275276
self.session_dict.pop(session_id)
276277

277278
def _create_session(self, name: str, user_id: str, config: dict[str, Any]):
279+
self.last_create_session_config = config
278280
new_session_id = '4'
279281
self.session_dict[new_session_id] = {
280282
'name': (
@@ -360,7 +362,8 @@ def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
360362

361363

362364
@pytest.fixture
363-
def mock_get_api_client():
365+
def mock_api_client_instance():
366+
"""Creates a mock API client instance for testing."""
364367
api_client = MockApiClient()
365368
api_client.session_dict = {
366369
'1': MOCK_SESSION_JSON_1,
@@ -373,9 +376,15 @@ def mock_get_api_client():
373376
'1': (copy.deepcopy(MOCK_EVENT_JSON), None),
374377
'2': (copy.deepcopy(MOCK_EVENT_JSON_2), 'my_token'),
375378
}
379+
return api_client
380+
381+
382+
@pytest.fixture
383+
def mock_get_api_client(mock_api_client_instance):
384+
"""Mocks the _get_api_client method to return a mock API client."""
376385
with mock.patch(
377386
'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client',
378-
return_value=api_client,
387+
return_value=mock_api_client_instance,
379388
):
380389
yield
381390

@@ -521,6 +530,21 @@ async def test_create_session_with_custom_session_id():
521530
)
522531

523532

533+
@pytest.mark.asyncio
534+
@pytest.mark.usefixtures('mock_get_api_client')
535+
async def test_create_session_with_custom_config(mock_api_client_instance):
536+
session_service = mock_vertex_ai_session_service()
537+
538+
expire_time = '2025-12-12T12:12:12.123456Z'
539+
await session_service.create_session(
540+
app_name='123', user_id='user', expire_time=expire_time
541+
)
542+
assert (
543+
mock_api_client_instance.last_create_session_config['expire_time']
544+
== expire_time
545+
)
546+
547+
524548
@pytest.mark.asyncio
525549
@pytest.mark.usefixtures('mock_get_api_client')
526550
async def test_append_event():

0 commit comments

Comments
 (0)