Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node_id to metadata #2912

Merged
merged 18 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def receive() -> Message:
task_id=str(uuid.uuid4()),
group_id="",
ttl="",
node_id=0,
task_type=task_type,
),
content=recordset,
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
),
Expand All @@ -60,6 +61,7 @@
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type="reconnect",
),
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type="reconnect",
),
Expand Down Expand Up @@ -150,9 +151,10 @@ def handle_legacy_message_from_tasktype(
# Return Message
out_message = Message(
metadata=Metadata(
run_id=0, # Non-user defined
task_id="", # Non-user defined
group_id="", # Non-user defined
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type=task_type,
),
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_client_without_get_properties() -> None:
run_id=0,
task_id=str(uuid.uuid4()),
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
),
Expand Down Expand Up @@ -162,6 +163,7 @@ def test_client_with_get_properties() -> None:
run_id=0,
task_id=str(uuid.uuid4()),
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_GET_PROPERTIES,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,14 @@ def secaggplus_mod(

# Return message
return Message(
metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
metadata=Metadata(
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_FIT,
),
content=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,29 @@ def get_test_handler(

def empty_ffn(_: Message, _2: Context) -> Message:
return Message(
metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
metadata=Metadata(
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_FIT,
),
content=RecordSet(),
)

app = make_ffn(empty_ffn, [secaggplus_mod])

def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]:
in_msg = Message(
metadata=Metadata(0, "", "", "", TASK_TYPE_FIT),
metadata=Metadata(
run_id=0,
task_id="",
group_id="",
node_id=0,
ttl="",
task_type=TASK_TYPE_FIT,
),
content=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}),
)
out_msg = app(in_msg, ctxt)
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def app(message: Message, context: Context) -> Message:
def _get_dummy_flower_message() -> Message:
return Message(
content=RecordSet(),
metadata=Metadata(run_id=0, task_id="", group_id="", ttl="", task_type="mock"),
metadata=Metadata(
run_id=0, task_id="", group_id="", node_id=0, ttl="", task_type="mock"
),
)


Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Metadata:
group_id : str
An identifier for grouping tasks. In some settings
this is used as the FL round.
node_id : int
An identifier for the node running a task.
ttl : str
Time-to-live for this task.
task_type : str
Expand All @@ -43,6 +45,7 @@ class Metadata:
run_id: int
task_id: str
group_id: str
node_id: int
ttl: str
task_type: str

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def message_from_taskins(taskins: TaskIns) -> Message:
run_id=taskins.run_id,
task_id=taskins.task_id,
group_id=taskins.group_id,
node_id=taskins.task.consumer.node_id,
ttl=taskins.task.ttl,
task_type=taskins.task.task_type,
)
Expand Down Expand Up @@ -592,6 +593,7 @@ def message_from_taskres(taskres: TaskRes) -> Message:
run_id=taskres.run_id,
task_id=taskres.task_id,
group_id=taskres.group_id,
node_id=taskres.task.consumer.node_id,
ttl=taskres.task.ttl,
task_type=taskres.task.task_type,
)
Expand Down
5 changes: 5 additions & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def metadata(self) -> Metadata:
run_id=self.rng.randint(0, 1 << 30),
task_id=self.get_str(64),
group_id=self.get_str(30),
node_id=self.rng.randint(0, 1 << 63),
ttl=self.get_str(10),
task_type=self.get_str(10),
)
Expand Down Expand Up @@ -309,6 +310,7 @@ def test_message_to_and_from_taskins() -> None:
run_id=0,
task_id="",
group_id="",
node_id=metadata.node_id,
ttl=metadata.ttl,
task_type=metadata.task_type,
),
Expand All @@ -320,6 +322,7 @@ def test_message_to_and_from_taskins() -> None:
taskins.run_id = metadata.run_id
taskins.task_id = metadata.task_id
taskins.group_id = metadata.group_id
taskins.task.consumer.node_id = metadata.node_id
deserialized = message_from_taskins(taskins)

# Assert
Expand All @@ -337,6 +340,7 @@ def test_message_to_and_from_taskres() -> None:
run_id=0,
task_id="",
group_id="",
node_id=metadata.node_id,
ttl=metadata.ttl,
task_type=metadata.task_type,
),
Expand All @@ -348,6 +352,7 @@ def test_message_to_and_from_taskres() -> None:
taskres.run_id = metadata.run_id
taskres.task_id = metadata.task_id
taskres.group_id = metadata.group_id
taskres.task.consumer.node_id = metadata.node_id
deserialized = message_from_taskres(taskres)

# Assert
Expand Down
Loading