Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn once if TTL is set #3195

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import time
import warnings
from dataclasses import dataclass

from .record import RecordSet
Expand Down Expand Up @@ -311,6 +312,13 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:

ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
"""
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
"version of Flower.",
stacklevel=2,
)
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl
Expand Down Expand Up @@ -349,6 +357,13 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
Message
A new `Message` instance representing the reply.
"""
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
"version of Flower.",
stacklevel=2,
)
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl
Expand Down
20 changes: 15 additions & 5 deletions src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# ==============================================================================
"""Flower driver service client."""


import time
import warnings
from typing import Iterable, List, Optional, Tuple

from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
Expand Down Expand Up @@ -91,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments
message_type: str,
dst_node_id: int,
group_id: str,
ttl: float = DEFAULT_TTL,
ttl: Optional[float] = None,
) -> Message:
"""Create a new message with specified parameters.

Expand All @@ -111,25 +111,35 @@ def create_message( # pylint: disable=too-many-arguments
group_id : str
The ID of the group to which this message is associated. In some settings,
this is used as the FL round.
ttl : float (default: common.DEFAULT_TTL)
ttl : Optional[float] (default: None)
Time-to-live for the round trip of this message, i.e., the time from sending
this message to receiving a reply. It specifies in seconds the duration for
which the message and its potential reply are considered valid.
which the message and its potential reply are considered valid. If unset,
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.

Returns
-------
message : Message
A new `Message` instance with the specified content and metadata.
"""
_, run_id = self._get_grpc_driver_and_run_id()
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
"version of Flower.",
stacklevel=2,
)

ttl_ = DEFAULT_TTL if ttl is None else ttl
metadata = Metadata(
run_id=run_id,
message_id="", # Will be set by the server
src_node_id=self.node.node_id,
dst_node_id=dst_node_id,
reply_to_message="",
group_id=group_id,
ttl=ttl,
ttl=ttl_,
message_type=message_type,
)
return Message(metadata=metadata, content=content)
Expand Down
5 changes: 1 addition & 4 deletions src/py/flwr/server/workflow/default_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Optional, cast

import flwr.common.recordset_compat as compat
from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log
from flwr.common import ConfigsRecord, Context, GetParametersIns, log
from flwr.common.constant import MessageType, MessageTypeLegacy

from ..compat.app_utils import start_update_client_manager_thread
Expand Down Expand Up @@ -127,7 +127,6 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
message_type=MessageTypeLegacy.GET_PARAMETERS,
dst_node_id=random_client.node_id,
group_id="0",
ttl=DEFAULT_TTL,
)
]
)
Expand Down Expand Up @@ -226,7 +225,6 @@ def default_fit_workflow( # pylint: disable=R0914
message_type=MessageType.TRAIN,
dst_node_id=proxy.node_id,
group_id=str(current_round),
ttl=DEFAULT_TTL,
)
for proxy, fitins in client_instructions
]
Expand Down Expand Up @@ -306,7 +304,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
message_type=MessageType.EVALUATE,
dst_node_id=proxy.node_id,
group_id=str(current_round),
ttl=DEFAULT_TTL,
)
for proxy, evalins in client_instructions
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import flwr.common.recordset_compat as compat
from flwr.common import (
DEFAULT_TTL,
ConfigsRecord,
Context,
FitRes,
Expand Down Expand Up @@ -374,7 +373,6 @@ def make(nid: int) -> Message:
message_type=MessageType.TRAIN,
dst_node_id=nid,
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
ttl=DEFAULT_TTL,
)

log(
Expand Down Expand Up @@ -422,7 +420,6 @@ def make(nid: int) -> Message:
message_type=MessageType.TRAIN,
dst_node_id=nid,
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
ttl=DEFAULT_TTL,
)

# Broadcast public keys to clients and receive secret key shares
Expand Down Expand Up @@ -493,7 +490,6 @@ def make(nid: int) -> Message:
message_type=MessageType.TRAIN,
dst_node_id=nid,
group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
ttl=DEFAULT_TTL,
)

log(
Expand Down Expand Up @@ -564,7 +560,6 @@ def make(nid: int) -> Message:
message_type=MessageType.TRAIN,
dst_node_id=nid,
group_id=str(current_round),
ttl=DEFAULT_TTL,
)

log(
Expand Down