diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index f52129a2ba11..f4183763cae5 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -159,6 +159,9 @@ def PushTaskIns( for task_ins in request.task_ins_list: validation_errors = validate_task_ins_or_res(task_ins) _raise_if(bool(validation_errors), ", ".join(validation_errors)) + _raise_if( + request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`" + ) # Store each TaskIns task_ids: list[Optional[UUID]] = [] @@ -193,6 +196,12 @@ def PullTaskRes( # Read from state task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids) + # Validate request + for task_res in task_res_list: + _raise_if( + request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`" + ) + # Delete the TaskIns/TaskRes pairs if TaskRes is found task_ins_ids_to_delete = { UUID(task_res.task.ancestry[0]) for task_res in task_res_list