26 changes: 14 additions & 12 deletions python-sdk/exospherehost/runtime.py
@@ -1,5 +1,7 @@
import asyncio
import os
import multiprocessing

🧹 Nitpick (assertive)

Confirm intent for new multiprocessing import (currently unused).

The new import isn’t referenced in this module. If you don’t plan to introduce multiprocessing-based code in this PR, please remove it to avoid confusion. If it’s a placeholder for imminent work, ignore this note.

🤖 Prompt for AI Agents
In python-sdk/exospherehost/runtime.py around line 3, the new import
multiprocessing is unused; remove the import to avoid confusion and linter
warnings (or, if you intended to add multiprocessing-based functionality in this
PR, add the corresponding code that uses multiprocessing and update
tests/documentation accordingly) — default action: delete the unused "import
multiprocessing" line.


from asyncio import Queue, sleep
from typing import List, Dict

@@ -34,7 +36,7 @@ class Runtime:
         key (str | None, optional): API key for authentication.
             If not provided, will use the EXOSPHERE_API_KEY environment variable.
         batch_size (int, optional): Number of states to fetch per poll. Defaults to 16.
-        workers (int, optional): Number of concurrent worker tasks. Defaults to 4.
+        thread_count (int, optional): Number of concurrent worker threads. Defaults to 4.
         state_manage_version (str, optional): State manager API version. Defaults to "v0".
         poll_interval (int, optional): Seconds between polling for new states. Defaults to 1.

@@ -47,13 +49,13 @@ class Runtime:
runtime.start()
"""

-    def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, workers: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
+    def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):

medium

Renaming workers to thread_count is potentially misleading. In asyncio, concurrency is handled by tasks running on a single OS thread's event loop. The term 'worker' is standard for describing a task that processes a queue. Using 'thread' might incorrectly suggest that the threading module and multiple OS threads are being used. It would be clearer to stick with workers or a similar term like worker_count.
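
To make the distinction concrete, here is a minimal sketch (not taken from the SDK; `report_thread` and `main` are hypothetical names) suggesting that any number of asyncio tasks report the same OS thread identifier:

```python
import asyncio
import threading


async def report_thread(results: list[int]) -> None:
    # Yield to the event loop once so the tasks actually interleave.
    await asyncio.sleep(0)
    results.append(threading.get_ident())


async def main(worker_count: int = 4) -> set[int]:
    results: list[int] = []
    tasks = [asyncio.create_task(report_thread(results)) for _ in range(worker_count)]
    await asyncio.gather(*tasks)
    # Collapse to the distinct thread identifiers that were observed.
    return set(results)


thread_ids = asyncio.run(main())
print(len(thread_ids))  # 1
```

However many tasks are created, only one thread identifier shows up, which is why a name containing "thread" can mislead callers.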

💡 Verification agent

🧩 Analysis chain

API misnomer: “thread_count” suggests OS threads, but implementation uses asyncio tasks. Change before publishing 0.0.1.

Exposing “thread_count” as a public constructor param (and naming the worker as “_worker_thread”) is misleading and will be hard to change post-release. The runtime spawns asyncio tasks (cooperative coroutines) via asyncio.create_task, not threads. Recommend renaming to worker_count (or task_count) and updating validation, docs, and call sites accordingly.
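
For reference, the pattern described here can be sketched in isolation. This is an illustrative stand-in, not the SDK's code: `worker`, `run_pool`, and the doubling step are assumptions, and only `asyncio.Queue` and `asyncio.create_task` mirror what the runtime is said to use.

```python
import asyncio


async def worker(queue: asyncio.Queue, handled: list[int]) -> None:
    # Each worker is a coroutine scheduled as a task, not an OS thread.
    while True:
        item = await queue.get()
        try:
            handled.append(item * 2)  # stand-in for executing a node
        finally:
            queue.task_done()


async def run_pool(items: list[int], worker_count: int = 4) -> list[int]:
    if worker_count < 1:
        raise ValueError("Worker count should be at least 1")
    queue: asyncio.Queue = asyncio.Queue()
    handled: list[int] = []
    workers = [asyncio.create_task(worker(queue, handled)) for _ in range(worker_count)]
    for item in items:
        queue.put_nowait(item)
    await queue.join()  # wait until every queued item has been processed
    for task in workers:
        task.cancel()  # the workers loop forever, so cancel them explicitly
    await asyncio.gather(*workers, return_exceptions=True)
    return sorted(handled)


print(asyncio.run(run_pool([1, 2, 3])))  # [2, 4, 6]
```

Nothing in this sketch touches the threading module, so `worker_count` describes it more accurately than `thread_count`.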

Proposed minimal API/name fixes:

-def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
+def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, worker_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
@@
-        self._thread_count = thread_count
+        self._worker_count = worker_count
@@
-        if self._thread_count < 1:
-            raise ValueError("Thread count should be at least 1")
+        if self._worker_count < 1:
+            raise ValueError("Worker count should be at least 1")
@@
-        worker_tasks = [asyncio.create_task(self._worker_thread()) for _ in range(self._thread_count)]
+        worker_tasks = [asyncio.create_task(self._worker_task()) for _ in range(self._worker_count)]

Also update docstrings to reflect the terminology and avoid implying OS threads:

-        workers (int, optional): Number of concurrent worker tasks. Defaults to 4.
+        worker_count (int, optional): Number of concurrent worker tasks (asyncio tasks). Defaults to 4.
@@
-        ValueError: If configuration is invalid (e.g., missing URI or key, batch_size/workers < 1).
+        ValueError: If configuration is invalid (e.g., missing URI or key, batch_size/worker_count < 1).

If you truly intend real threads, we’d need a different design (e.g., threading or ThreadPoolExecutor, and careful event loop handoff), which is a larger change.
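
As a rough illustration of what a thread-based design could look like (an assumption for discussion, not a proposal for this PR; `blocking_job` and `run_in_threads` are hypothetical names), blocking work can be handed from the event loop to a `ThreadPoolExecutor` via `loop.run_in_executor`:

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_job(n: int) -> tuple[int, int]:
    # Stand-in for blocking work; this runs on a pool thread, not the event loop.
    time.sleep(0.01)
    return n * n, threading.get_ident()


async def run_in_threads(values: list[int], thread_count: int = 4) -> list[tuple[int, int]]:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=thread_count) as pool:
        # Each call is submitted to the pool; the loop stays free while threads work.
        futures = [loop.run_in_executor(pool, blocking_job, v) for v in values]
        return await asyncio.gather(*futures)


results = asyncio.run(run_in_threads([1, 2, 3]))
print([square for square, _ in results])  # [1, 4, 9]
```

In this variant the name `thread_count` would be accurate, because `max_workers` really does bound the number of OS threads.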

Run this to find lingering references that need renaming and stale “workers” mentions in code/docs:

Also applies to: 58-58, 91-92, 351-351


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -nP -C2 '(?<!_)workers?\b|\bthread_count\b|\b_thread_count\b|\b_worker_thread\b'

Length of output: 7529


Rename thread_count to worker_count (asyncio tasks, not OS threads)

The constructor parameter and internal naming currently suggest OS threads, but the runtime spawns asyncio tasks. To avoid a breaking API change post-release, please rename all thread_count/workers references to worker_count (or task_count) before publishing v0.0.1.

Locations to update:

  • python-sdk/exospherehost/runtime.py

    • Line 52: constructor signature – thread_count → worker_count
    • Line 58: self._thread_count → self._worker_count
    • Lines 91–92: validation – rename and error message to "Worker count should be at least 1"
    • Line 311: private method – rename _worker_thread → _worker_task
    • Line 351: spawn loop – use self._worker_count and call _worker_task()
    • Docstrings around lines 39–44: update parameter name and description (workers → worker_count and clarify “asyncio tasks”)
  • python-sdk/tests/test_runtime_validation.py

    • Line 65: test invocation – change workers=0 → worker_count=0

Proposed diffs:

--- a/python-sdk/exospherehost/runtime.py
@@ -52,7 +52,7 @@ class Runtime:
-    def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
+    def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, worker_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
@@ -58,7 +58,7 @@ class Runtime:
-        self._thread_count = thread_count
+        self._worker_count = worker_count
@@ -91,5 +91,5 @@ class Runtime:
-        if self._thread_count < 1:
-            raise ValueError("Thread count should be at least 1")
+        if self._worker_count < 1:
+            raise ValueError("Worker count should be at least 1")
@@ -311,4 +311,4 @@ class Runtime:
-    async def _worker_thread(self):
+    async def _worker_task(self):
@@ -351,4 +351,4 @@ class Runtime:
-        worker_tasks = [asyncio.create_task(self._worker_thread()) for _ in range(self._thread_count)]
+        worker_tasks = [asyncio.create_task(self._worker_task()) for _ in range(self._worker_count)]
--- a/python-sdk/tests/test_runtime_validation.py
@@ -63,7 +63,7 @@ def test_batch_size_validation():
         Runtime(namespace="ns", name="rt", nodes=[GoodNode], batch_size=0)
     with pytest.raises(ValueError):
-        Runtime(namespace="ns", name="rt", nodes=[GoodNode], workers=0)
+        Runtime(namespace="ns", name="rt", nodes=[GoodNode], worker_count=0)

Please apply these changes across docstrings and any other call sites to ensure consistency. Let me know if you need help updating remaining references.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
# File: python-sdk/exospherehost/runtime.py
def __init__(
self,
namespace: str,
name: str,
nodes: List[type[BaseNode]],
state_manager_uri: str | None = None,
key: str | None = None,
batch_size: int = 16,
worker_count: int = 4,
state_manage_version: str = "v0",
poll_interval: int = 1,
):
# ...
self._worker_count = worker_count
# ...
# validation
- if self._thread_count < 1:
if self._worker_count < 1:
raise ValueError("Worker count should be at least 1")
# ...
async def _worker_task(self):
# ...
# spawn worker tasks
worker_tasks = [asyncio.create_task(self._worker_task()) for _ in range(self._worker_count)]
# ...
Suggested change
def __init__(self, namespace: str, name: str, nodes: List[type[BaseNode]], state_manager_uri: str | None = None, key: str | None = None, batch_size: int = 16, thread_count: int = 4, state_manage_version: str = "v0", poll_interval: int = 1):
# File: python-sdk/tests/test_runtime_validation.py
def test_batch_size_validation():
Runtime(namespace="ns", name="rt", nodes=[GoodNode], batch_size=1)
with pytest.raises(ValueError):
Runtime(namespace="ns", name="rt", nodes=[GoodNode], worker_count=0)
🤖 Prompt for AI Agents
In python-sdk/exospherehost/runtime.py around line 52 (constructor) and nearby
docstrings (lines ~39–44), rename the parameter thread_count to worker_count and
update its docstring to clarify it represents asyncio tasks; change internal
attribute self._thread_count to self._worker_count (line 58) and update
validation on lines ~91–92 to validate worker_count and use the error message
"Worker count should be at least 1"; rename the private method _worker_thread to
_worker_task (around line 311) and in the spawn loop (around line 351) use
self._worker_count and call _worker_task(); also update any other references in
this file and tests (python-sdk/tests/test_runtime_validation.py line 65: pass
worker_count=0) and adjust names in docstrings and error messages to ensure
consistency across the codebase.

         self._name = name
         self._namespace = namespace
         self._key = key
         self._batch_size = batch_size
         self._state_queue = Queue(maxsize=2*batch_size)
-        self._workers = workers
+        self._thread_count = thread_count
         self._nodes = nodes
         self._node_names = [node.__name__ for node in nodes]
         self._state_manager_uri = state_manager_uri
@@ -81,13 +83,13 @@ def _validate_runtime(self):
         Validate runtime configuration.

         Raises:
-            ValueError: If batch_size or workers is less than 1, or if required
+            ValueError: If batch_size or thread_count is less than 1, or if required
                 configuration (state_manager_uri, key) is not provided.
         """
         if self._batch_size < 1:
             raise ValueError("Batch size should be at least 1")
-        if self._workers < 1:
-            raise ValueError("Workers should be at least 1")
+        if self._thread_count < 1:
+            raise ValueError("Thread count should be at least 1")
         if self._state_manager_uri is None:
             raise ValueError("State manager URI is not set")
         if self._key is None:
@@ -306,9 +308,9 @@ def _validate_nodes(self):
         if len(errors) > 0:
             raise ValueError("Following errors while validating nodes: " + "\n".join(errors))

-    async def _worker(self):
+    async def _worker_thread(self):
         """
-        Worker task that processes states from the queue.
+        Worker thread that processes states from the queue.

         Continuously fetches states from the queue, executes the corresponding node,
         and notifies the state manager of the result.
@@ -319,7 +321,7 @@ async def _worker(self):
             try:
                 node = self._node_mapping[state["node_name"]]
                 secrets = await self._get_secrets(state["state_id"])
-                outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"]))
+                outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore

medium

This # type: ignore comment was added to suppress a typing error, but this change is unrelated to the main goal of this PR (renaming workers to threads). It's best practice to keep pull requests focused on a single concern. Please move this fix to a separate PR. The typing error itself points to a potential design issue in BaseNode's type hierarchy that should be investigated.

Suggested change
outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore
outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"]))

🛠️ Refactor suggestion

⚠️ Potential issue

Secrets shape mismatch and unnecessary type: ignore. Return a flat mapping from _get_secrets and pass it directly.

Currently, _get_secrets() is annotated as Dict[str, str] but the worker indexes secrets["secrets"], forcing a type ignore. Fix by returning the inner map from _get_secrets and using it directly:

Change the call site (adjust names if you apply the worker refactor above):

-outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore
+outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets))

And update _get_secrets to return the flat dict:

-    async def _get_secrets(self, state_id: str) -> Dict[str, str]:
+    async def _get_secrets(self, state_id: str) -> Dict[str, str]:
@@
-            async with session.get(endpoint, headers=headers) as response: # type: ignore
+            async with session.get(endpoint, headers=headers) as response:  # type: ignore
                 res = await response.json()
@@
-                return res
+                # Expecting server payload like: {"secrets": {"KEY": "VALUE", ...}}
+                return res.get("secrets", {})

This removes the type suppression and aligns annotation with actual return.
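
The unwrapping step can also be sketched on its own. The helper name `extract_secrets` is hypothetical, and the payload shape `{"secrets": {...}}` is taken from the comment in the diff above rather than from a confirmed server contract:

```python
from typing import Any, Dict


def extract_secrets(payload: Dict[str, Any]) -> Dict[str, str]:
    """Return the inner secrets mapping, or an empty dict if it is absent."""
    secrets = payload.get("secrets") or {}
    if not isinstance(secrets, dict):
        raise ValueError("Expected 'secrets' to be a JSON object")
    # Coerce keys and values to str so the result matches Dict[str, str].
    return {str(key): str(value) for key, value in secrets.items()}


print(extract_secrets({"secrets": {"API_KEY": "abc"}}))  # {'API_KEY': 'abc'}
print(extract_secrets({}))  # {}
```

With a helper along these lines, the call site could pass the result straight into node.Secrets(**secrets) without indexing or a type suppression.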

🤖 Prompt for AI Agents
In python-sdk/exospherehost/runtime.py around line 324, the call currently
indexes secrets["secrets"] and uses a "# type: ignore" because _get_secrets
returns a nested shape; change _get_secrets to return a flat Dict[str, str] (the
inner mapping) and update its type annotation accordingly, then at this call
site pass the returned dict directly into node.Secrets(...) (and remove the "#
type: ignore"), ensuring the shapes align and type checking passes.


                 if outputs is None:
                     outputs = []
@@ -338,17 +340,17 @@ async def _start(self):
         """
         Start the runtime event loop.

-        Registers nodes, starts the polling and worker tasks, and runs until stopped.
+        Registers nodes, starts the polling and worker threads, and runs until stopped.

         Raises:
             RuntimeError: If the runtime is not connected (no nodes registered).
         """
         await self._register()

         poller = asyncio.create_task(self._enqueue())
-        worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)]
+        worker_threads = [asyncio.create_task(self._worker_thread()) for _ in range(self._thread_count)]

-        await asyncio.gather(poller, *worker_tasks)
+        await asyncio.gather(poller, *worker_threads)

     def start(self):
         """
6 changes: 3 additions & 3 deletions python-sdk/tests/test_runtime_validation.py
@@ -15,11 +15,11 @@ class Secrets(BaseModel):
         api_key: str

     async def execute(self):
-        return self.Outputs(message=f"hi {self.inputs.name}")
+        return self.Outputs(message=f"hi {self.inputs.name}") # type: ignore


 class BadNodeWrongInputsBase(BaseNode):
-    Inputs = object # not a pydantic BaseModel
+    Inputs = object # not a pydantic BaseModel # type: ignore
     class Outputs(BaseModel):
         message: str
     class Secrets(BaseModel):
@@ -62,7 +62,7 @@ def test_runtime_invalid_params_raises(monkeypatch):
     with pytest.raises(ValueError):
         Runtime(namespace="ns", name="rt", nodes=[GoodNode], batch_size=0)
     with pytest.raises(ValueError):
-        Runtime(namespace="ns", name="rt", nodes=[GoodNode], workers=0)
+        Runtime(namespace="ns", name="rt", nodes=[GoodNode], thread_count=0)


 def test_node_validation_errors(monkeypatch):