Skip to content

Commit

Permalink
feat(framework) Add ClientAppIo servicer test (#4001)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Aug 16, 2024
1 parent 61e8282 commit 3203446
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions src/py/flwr/client/process/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,25 @@
"""Test the ClientAppIo API servicer."""

import unittest
from unittest.mock import Mock, patch

from flwr.client.process.process import pull_message, push_message
from flwr.common import Context, Message, typing
from flwr.common.serde import (
clientappstatus_from_proto,
clientappstatus_to_proto,
message_to_proto,
)
from flwr.common.serde_test import RecordMaker

# pylint:disable=E0611
from flwr.proto.clientappio_pb2 import (
PullClientAppInputsResponse,
PushClientAppOutputsResponse,
)
from flwr.proto.message_pb2 import Context as ProtoContext
from flwr.proto.run_pb2 import Run as ProtoRun

from .clientappio_servicer import (
ClientAppIoInputs,
ClientAppIoOutputs,
Expand All @@ -33,9 +48,15 @@ def setUp(self) -> None:
"""Initialize."""
self.servicer = ClientAppIoServicer()
self.maker = RecordMaker()
self.mock_stub = Mock()
self.patcher = patch(
"flwr.client.process.process.ClientAppIoStub", return_value=self.mock_stub
)
self.patcher.start()

def tearDown(self) -> None:
"""Cleanup."""
self.patcher.stop()

def test_set_inputs(self) -> None:
"""Test setting ClientApp inputs."""
Expand Down Expand Up @@ -116,3 +137,60 @@ def test_get_outputs(self) -> None:
assert output == client_output
assert self.servicer.clientapp_input is None
assert self.servicer.clientapp_output is None

def test_pull_clientapp_inputs(self) -> None:
"""Test pulling messages from SuperNode."""
# Prepare
mock_message = Message(
metadata=self.maker.metadata(),
content=self.maker.recordset(3, 2, 1),
)
mock_response = PullClientAppInputsResponse(
message=message_to_proto(mock_message),
context=ProtoContext(node_id=123),
run=ProtoRun(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"),
)
self.mock_stub.PullClientAppInputs.return_value = mock_response

# Execute
message, context, run = pull_message(self.mock_stub, token=456)

# Assert
self.mock_stub.PullClientAppInputs.assert_called_once()
self.assertEqual(len(message.content.parameters_records), 3)
self.assertEqual(len(message.content.metrics_records), 2)
self.assertEqual(len(message.content.configs_records), 1)
self.assertEqual(context.node_id, 123)
self.assertEqual(run.run_id, 61016)
self.assertEqual(run.fab_id, "mock/mock")
self.assertEqual(run.fab_version, "v1.0.0")

def test_push_clientapp_outputs(self) -> None:
"""Test pushing messages to SuperNode."""
# Prepare
message = Message(
metadata=self.maker.metadata(),
content=self.maker.recordset(2, 2, 1),
)
context = Context(
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
run_config={"runconfig1": 6.1},
)
code = typing.ClientAppOutputCode.SUCCESS
status_proto = clientappstatus_to_proto(
status=typing.ClientAppOutputStatus(code=code, message="SUCCESS"),
)
mock_response = PushClientAppOutputsResponse(status=status_proto)
self.mock_stub.PushClientAppOutputs.return_value = mock_response

# Execute
res = push_message(
stub=self.mock_stub, token=789, message=message, context=context
)
status = clientappstatus_from_proto(res.status)

# Assert
self.mock_stub.PushClientAppOutputs.assert_called_once()
self.assertEqual(status.message, "SUCCESS")

0 comments on commit 3203446

Please sign in to comment.