diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 04c7b7b5..8ee8e0c1 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -1,5 +1,7 @@ import asyncio import os +import multiprocessing + from asyncio import Queue, sleep from typing import List, Dict @@ -34,7 +36,7 @@ class Runtime: key (str | None, optional): API key for authentication. If not provided, will use the EXOSPHERE_API_KEY environment variable. batch_size (int, optional): Number of states to fetch per poll. Defaults to 16. - workers (int, optional): Number of concurrent worker tasks. Defaults to 4. + thread_count (int, optional): Number of concurrent worker threads. Defaults to 4. state_manage_version (str, optional): State manager API version. Defaults to "v0". poll_interval (int, optional): Seconds between polling for new states. Defaults to 1. @@ -47,13 +49,13 @@ class Runtime: runtime.start() """ - def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, workers: int = 4, state_manage_version: str = "v0", poll_interval: int = 1): + def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1): self._name = name self._namespace = namespace self._key = key self._batch_size = batch_size self._state_queue = Queue(maxsize=2*batch_size) - self._workers = workers + self._thread_count = thread_count self._nodes = nodes self._node_names = [node.__name__ for node in nodes] self._state_manager_uri = state_manager_uri @@ -81,13 +83,13 @@ def _validate_runtime(self): Validate runtime configuration. Raises: - ValueError: If batch_size or workers is less than 1, or if required + ValueError: If batch_size or thread_count is less than 1, or if required configuration (state_manager_uri, key) is not provided. """ if self._batch_size < 1: raise ValueError("Batch size should be at least 1") - if self._workers < 1: - raise ValueError("Workers should be at least 1") + if self._thread_count < 1: + raise ValueError("Thread count should be at least 1") if self._state_manager_uri is None: raise ValueError("State manager URI is not set") if self._key is None: @@ -306,9 +308,9 @@ def _validate_nodes(self): if len(errors) > 0: raise ValueError("Following errors while validating nodes: " + "\n".join(errors)) - async def _worker(self): + async def _worker_thread(self): """ - Worker task that processes states from the queue. + Worker thread that processes states from the queue. Continuously fetches states from the queue, executes the corresponding node, and notifies the state manager of the result. @@ -319,7 +321,7 @@ async def _worker(self): try: node = self._node_mapping[state["node_name"]] secrets = await self._get_secrets(state["state_id"]) - outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) + outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore if outputs is None: outputs = [] @@ -338,7 +340,7 @@ async def _start(self): """ Start the runtime event loop. - Registers nodes, starts the polling and worker tasks, and runs until stopped. + Registers nodes, starts the polling and worker threads, and runs until stopped. Raises: RuntimeError: If the runtime is not connected (no nodes registered). @@ -346,9 +348,9 @@ async def _start(self): await self._register() poller = asyncio.create_task(self._enqueue()) - worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)] + worker_threads = [asyncio.create_task(self._worker_thread()) for _ in range(self._thread_count)] - await asyncio.gather(poller, *worker_tasks) + await asyncio.gather(poller, *worker_threads) def start(self): """ diff --git a/python-sdk/tests/test_runtime_validation.py b/python-sdk/tests/test_runtime_validation.py index d135b258..53ca2620 100644 --- a/python-sdk/tests/test_runtime_validation.py +++ b/python-sdk/tests/test_runtime_validation.py @@ -15,11 +15,11 @@ class Secrets(BaseModel): api_key: str async def execute(self): - return self.Outputs(message=f"hi {self.inputs.name}") + return self.Outputs(message=f"hi {self.inputs.name}") # type: ignore class BadNodeWrongInputsBase(BaseNode): - Inputs = object # not a pydantic BaseModel + Inputs = object # not a pydantic BaseModel # type: ignore class Outputs(BaseModel): message: str class Secrets(BaseModel): @@ -62,7 +62,7 @@ def test_runtime_invalid_params_raises(monkeypatch): with pytest.raises(ValueError): Runtime(namespace="ns", name="rt", nodes=[GoodNode], batch_size=0) with pytest.raises(ValueError): - Runtime(namespace="ns", name="rt", nodes=[GoodNode], workers=0) + Runtime(namespace="ns", name="rt", nodes=[GoodNode], thread_count=0) def test_node_validation_errors(monkeypatch):