Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ACADynamicSessionsCodeExecutor(CodeExecutor):
directory is a temporal directory.
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
suppress_result_output bool: By default the executor will attach any result info in the execution response to the result outpu. Set this to True to prevent this.
session_id (str): The session id for the code execution (passed to Dynamic Sessions). If None, a new session id will be generated. Default is None. Note this value will be reset when calling `restart`

.. note::
Using the current directory (".") as working directory is deprecated. Using it will raise a deprecation warning.
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
] = [],
functions_module: str = "functions",
suppress_result_output: bool = False,
session_id: Optional[str] = None,
):
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
Expand Down Expand Up @@ -141,7 +143,7 @@ def __init__(

self._pool_management_endpoint = pool_management_endpoint
self._access_token: str | None = None
self._session_id: str = str(uuid4())
self._session_id: str = session_id or str(uuid4())
self._available_packages: set[str] | None = None
self._credential: TokenProvider = credential
# cwd needs to be set to /mnt/data to properly read uploaded files and download written files
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@
POOL_ENDPOINT = os.getenv(ENVIRON_KEY_AZURE_POOL_ENDPOINT)


def test_session_id_preserved_if_passed() -> None:
executor = ACADynamicSessionsCodeExecutor(
pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential()
)
session_id = "test_session_id"
executor._session_id = session_id # type: ignore[reportPrivateUsage]
assert executor._session_id == session_id # type: ignore[reportPrivateUsage]


def test_session_id_generated_if_not_passed() -> None:
executor = ACADynamicSessionsCodeExecutor(
pool_management_endpoint="fake-endpoint", credential=DefaultAzureCredential()
)
assert executor._session_id is not None # type: ignore[reportPrivateUsage]
assert len(executor._session_id) > 0 # type: ignore[reportPrivateUsage]


@pytest.mark.skipif(
not POOL_ENDPOINT,
reason="do not run if pool endpoint is not defined",
Expand Down
Loading