diff --git a/tests/test_grpc.py b/tests/test_grpc.py index 15462cfb6..3a3ef849c 100644 --- a/tests/test_grpc.py +++ b/tests/test_grpc.py @@ -261,8 +261,9 @@ def test_message_loop(mock_heartbeat_request, _mock_sleep, _mock_event): channel = mock.MagicMock() terminate_event = threading.Event() state_record = StateRecord() + participant_id = "123" - message_loop(channel, state_record, terminate_event) + message_loop(channel, participant_id, state_record, terminate_event) # check that the heartbeat is sent exactly twice expected_call = mock.call(round=-1, state=State.READY) @@ -293,9 +294,9 @@ def test_start_training_round(coordinator_service): # simulate a participant communicating with coordinator via channel with grpc.insecure_channel("localhost:50051") as channel: # we need to rendezvous before we can send any other requests - rendezvous(channel) + rendezvous(channel, participant_id="123") # call StartTrainingRound service method on coordinator - epochs, epoch_base = start_training_round(channel) + epochs, epoch_base = start_training_round(channel, participant_id="123") # check global model received assert epochs == 5 @@ -364,7 +365,7 @@ def test_end_training_round( with grpc.insecure_channel("localhost:50051") as channel: # we first need to rendezvous before we can send any other request - rendezvous(channel) + rendezvous(channel, participant_id="123") # call EndTrainingRound service method on coordinator participant_store.write_weights("participant1", 0, test_weights) end_training_round( diff --git a/xain_fl/coordinator/coordinator.py b/xain_fl/coordinator/coordinator.py index e3ab8a042..274181d02 100644 --- a/xain_fl/coordinator/coordinator.py +++ b/xain_fl/coordinator/coordinator.py @@ -275,7 +275,8 @@ def select_outstanding(self) -> List[str]: frac = num_outstanding / len(pool) self.controller.fraction_of_participants = frac - return self.controller.select_ids(list(pool)) + outstanding: List[str] = self.controller.select_ids(list(pool)) + return outstanding def _handle_rendezvous( self, _message: RendezvousRequest, participant_id: str