diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py index 41b23c658141..17c4b16a2c15 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py @@ -73,6 +73,7 @@ class ACADynamicSessionsCodeExecutor(CodeExecutor): directory is a temporal directory. functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. suppress_result_output bool: By default the executor will attach any result info in the execution response to the result outpu. Set this to True to prevent this. + session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart` .. note:: Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning. @@ -102,6 +103,7 @@ def __init__( ] = [], functions_module: str = "functions", suppress_result_output: bool = False, + session_id: Optional[str] = None, ): if timeout < 1: raise ValueError("Timeout must be greater than or equal to 1.") @@ -141,7 +143,7 @@ def __init__( self._pool_management_endpoint = pool_management_endpoint self._access_token: str | None = None - self._session_id: str = str(uuid4()) + self._session_id: str = session_id or str(uuid4()) self._available_packages: set[str] | None = None self._credential: TokenProvider = credential # cwd needs to be set to /mnt/data to properly read uploaded files and download written files diff --git a/python/packages/autogen-ext/tests/code_executors/test_aca_dynamic_sessions.py b/python/packages/autogen-ext/tests/code_executors/test_aca_dynamic_sessions.py index 8fa8503ca154..aa13f1549f8b 100644 --- a/python/packages/autogen-ext/tests/code_executors/test_aca_dynamic_sessions.py +++ b/python/packages/autogen-ext/tests/code_executors/test_aca_dynamic_sessions.py @@ -22,6 +22,23 @@ POOL_ENDPOINT = os.getenv(ENVIRON_KEY_AZURE_POOL_ENDPOINT) +def test_session_id_preserved_if_passed() -> None: + executor = ACADynamicSessionsCodeExecutor( + pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential() + ) + session_id = "test_session_id" + executor._session_id = session_id # type: ignore[reportPrivateUsage] + assert executor._session_id == session_id # type: ignore[reportPrivateUsage] + + +def test_session_id_generated_if_not_passed() -> None: + executor = ACADynamicSessionsCodeExecutor( + pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential() + ) + assert executor._session_id is not None # type: ignore[reportPrivateUsage] + assert len(executor._session_id) > 0 # type: ignore[reportPrivateUsage] + + @pytest.mark.skipif( not POOL_ENDPOINT, reason="do not run if pool endpoint is not defined",