Skip to content

Commit

Permalink
Integrate Upstream Changes. Closes #157
Browse files Browse the repository at this point in the history
  • Loading branch information
umesh-timalsina committed Sep 19, 2023
1 parent 7ab4d07 commit 5f86e79
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 24 deletions.
10 changes: 8 additions & 2 deletions chimerapy/orchestrator/models/cluster_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,20 @@ class NodeState(BaseModel):
"SAVED",
"SHUTDOWN",
]
registered_methods: Dict[str, Any] = Field(default_factory=dict)
registered_methods: Dict[str, RegisteredMethod] = Field(
default_factory=dict
)
logdir: Optional[str] = None
diagnostics: NodeDiagnostics

@classmethod
def from_cp_node_state(cls, node_state: _NodeState):
node_state_dict = node_state.to_dict()
node_state_dict["logdir"] = str(node_state_dict["logdir"]) if node_state_dict["logdir"] is not None else None
node_state_dict["logdir"] = (
str(node_state_dict["logdir"])
if node_state_dict["logdir"] is not None
else None
)
return cls(**node_state.to_dict())

model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True, extra="forbid")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def update_network_status(self) -> None:

def is_zeroconf_discovery_enabled(self) -> bool:
"""Check if zeroconf discovery is enabled."""
return self._manager.services.zeroconf.enabled
return self._manager.zeroconf_service.enabled

async def instantiate_pipeline(
self, pipeline_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def add_client(
await self.updater.add_client(q)

if message is not None:
await q.put(message.dict())
await q.put(message.model_dump(mode="json"))

async def remove_client(self, q: asyncio.Queue) -> None:
"""Remove a client queue from the broadcaster."""
Expand Down Expand Up @@ -139,7 +139,7 @@ async def broadcast_updates(self) -> None:
else:
msg = None
if msg is not None:
msg_dict = msg.dict()
msg_dict = msg.model_dump(mode="json")
await self.updater.put_update(msg_dict)
if msg and msg.signal is UpdateMessageType.SHUTDOWN:
break
Expand All @@ -165,7 +165,7 @@ async def put_update(self, msg: Dict[str, Any]) -> None:
UpdateMessageType.NETWORK_UPDATE,
self.zeroconf_enabled,
)
await self.updater.put_update(update_msg.dict())
await self.updater.put_update(update_msg.model_dump(mode="json"))

@staticmethod
def is_cluster_update_message(msg: Dict[str, Any]) -> bool:
Expand Down
144 changes: 129 additions & 15 deletions chimerapy/orchestrator/tests/models/test_cluster_models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
from pathlib import Path

import pytest

from chimerapy.engine.states import (
ManagerState as _ManagerState,
WorkerState as _WorkerState,
)
from chimerapy.engine.states import (
NodeState as _NodeState,
)

from chimerapy.engine.states import (
RegisteredMethod as _RegisteredMethod,
)
from chimerapy.engine.states import (
WorkerState as _WorkerState,
)
from chimerapy.orchestrator.models.cluster_models import (
ClusterState,
NodeState,
RegisteredMethod,
WorkerState,
ClusterState,
)

from pathlib import Path
from chimerapy.orchestrator.tests.base_test import BaseTest

import pytest


class TestClusterModels(BaseTest):
@pytest.fixture(scope="class")
Expand All @@ -32,7 +39,7 @@ def m_populated(self):
port=55000,
workers={},
log_sink_enabled=True,
logs_subscription_port=55001
logs_subscription_port=55001,
)

@pytest.fixture(scope="class")
Expand All @@ -51,21 +58,84 @@ def m_w_empty(self):
port=55002,
ip="192.168.2.1",
nodes={},
tempfolder=Path("/tmp24/w1")
tempfolder=Path("/tmp24/w1"),
),
"w2": _WorkerState(
id="w2",
name="worker2",
port=55003,
ip="192.168.2.2",
nodes={},
tempfolder=Path("/tmp24/w2")
)
}
tempfolder=Path("/tmp24/w2"),
),
},
)

@pytest.fixture(scope="class")
def m_w_populated(self):
return _ManagerState(
id="manager1",
logdir=Path("/tmp24"),
ip="192.168.2.0",
port=55000,
log_sink_enabled=True,
logs_subscription_port=55001,
workers={
"w1": _WorkerState(
id="w1",
name="worker1",
port=55002,
ip="192.168.2.1",
nodes={
"n1": _NodeState(
id="n1",
name="node1",
port=55004,
fsm="NULL",
registered_methods={
"m1": _RegisteredMethod(
name="func1",
style="concurrent",
params={
"p1": "int",
},
)
},
)
},
tempfolder=Path("/tmp24/w1"),
),
"w2": _WorkerState(
id="w2",
name="worker2",
port=55003,
ip="192.168.2.2",
nodes={
"n2": _NodeState(
id="n2",
name="node2",
port=55005,
fsm="INITIALIZED",
registered_methods={
"m2": _RegisteredMethod(
name="func2",
style="concurrent",
params={
"p2": "float",
},
)
},
)
},
tempfolder=Path("/tmp24/w2"),
),
},
)

def test_m_empty(self, m_empty):
manager_state = ClusterState.from_cp_manager_state(m_empty, zeroconf_discovery=False)
manager_state = ClusterState.from_cp_manager_state(
m_empty, zeroconf_discovery=False
)
assert manager_state.id == "manager1"
assert manager_state.logdir == str(Path.cwd())
assert manager_state.log_sink_enabled is False
Expand All @@ -75,7 +145,9 @@ def test_m_empty(self, m_empty):
assert manager_state.workers == {}

def test_m_populated(self, m_populated):
manager_state = ClusterState.from_cp_manager_state(m_populated, zeroconf_discovery=True)
manager_state = ClusterState.from_cp_manager_state(
m_populated, zeroconf_discovery=True
)
assert manager_state.id == "manager1"
assert manager_state.logdir == "/tmp24"
assert manager_state.log_sink_enabled is True
Expand All @@ -86,7 +158,9 @@ def test_m_populated(self, m_populated):
assert manager_state.zeroconf_discovery is True

def test_m_w_empty(self, m_w_empty):
manager_state = ClusterState.from_cp_manager_state(m_w_empty, zeroconf_discovery=True)
manager_state = ClusterState.from_cp_manager_state(
m_w_empty, zeroconf_discovery=True
)
assert len(manager_state.workers) == 2
w1 = manager_state.workers["w1"]
assert w1.id == "w1"
Expand All @@ -106,4 +180,44 @@ def test_m_w_empty(self, m_w_empty):
assert w2.nodes == {}
assert isinstance(w2, WorkerState)

def test_m_w_populated(self, m_w_populated):
manager_state = ClusterState.from_cp_manager_state(
m_w_populated, zeroconf_discovery=True
)
assert len(manager_state.workers) == 2
w1 = manager_state.workers["w1"]

n1 = w1.nodes["n1"]
assert isinstance(n1, NodeState)
assert n1.id == "n1"
assert n1.name == "node1"
assert n1.port == 55004
assert n1.fsm == "NULL"

assert n1.registered_methods == {
"m1": RegisteredMethod(
name="func1",
style="concurrent",
params={
"p1": "int",
},
)
}

w2 = manager_state.workers["w2"]
n2 = w2.nodes["n2"]
assert isinstance(n2, NodeState)
assert n2.id == "n2"
assert n2.name == "node2"
assert n2.port == 55005
assert n2.fsm == "INITIALIZED"

assert n2.registered_methods == {
"m2": RegisteredMethod(
name="func2",
style="concurrent",
params={
"p2": "float",
},
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def dummy_pipeline_config(self) -> ChimeraPyPipelineConfig:

def test_pipeline_config(self, dummy_pipeline_config):
assert dummy_pipeline_config.name == "Pipeline"
assert dummy_pipeline_config.description == "A pipeline"
assert dummy_pipeline_config.description == ""
assert dummy_pipeline_config.runtime == 2000

def test_worker_config(self, dummy_pipeline_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ def test_get_network(self, cluster_manager):
assert cluster_manager.get_network().map(
lambda n: n.to_dict()
).unwrap() == {
"id": "Manager",
"id": cluster_manager._manager.state.id, # pylint: disable=protected-access
"workers": {},
"ip": get_ip_address(),
"port": cluster_manager._manager.port, # pylint: disable=protected-access
"workers": {},
"logs_subscription_port": None,
"log_sink_enabled": True,
"logdir": str(
cluster_manager._manager.logdir
), # pylint: disable=protected-access
}

@pytest.mark.timeout(30)
Expand Down

0 comments on commit 5f86e79

Please sign in to comment.