Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ttl a float #3166

Merged
merged 16 commits into from
Mar 26, 2024
6 changes: 3 additions & 3 deletions examples/app-pytorch/client_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message:
@app.train()
def train(msg: Message, ctx: Context):
print("`train` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.evaluate()
def eval(msg: Message, ctx: Context):
print("`evaluate` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)


@app.query()
def query(msg: Message, ctx: Context):
print("`query` is not implemented, echoing original message")
return msg.create_reply(msg.content, ttl="")
return msg.create_reply(msg.content)
3 changes: 2 additions & 1 deletion examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Message,
MessageType,
Metrics,
DEFAULT_TTL,
)
from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres
from flwr.server import Driver, History
Expand Down Expand Up @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand Down
12 changes: 10 additions & 2 deletions examples/app-pytorch/server_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import time

import flwr as fl
from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet
from flwr.common import (
Context,
NDArrays,
Message,
MessageType,
Metrics,
RecordSet,
DEFAULT_TTL,
)
from flwr.server import Driver


Expand All @@ -30,7 +38,7 @@ def main(driver: Driver, context: Context) -> None:
message_type=MessageType.TRAIN,
dst_node_id=node_id,
group_id=str(server_round),
ttl="",
ttl=DEFAULT_TTL,
)
messages.append(message)

Expand Down
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ message Task {
Node consumer = 2;
string created_at = 3;
string delivered_at = 4;
string ttl = 5;
double ttl = 5;
repeated string ancestry = 6;
string task_type = 7;
RecordSet recordset = 8;
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def train(message: Message, context: Context) -> Message:
>>> print("ClientApp training running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def evaluate(message: Message, context: Context) -> Message:
>>> print("ClientApp evaluation running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]:
>>> def query(message: Message, context: Context) -> Message:
>>> print("ClientApp query running")
>>> # Create and return an echo reply message
>>> return message.create_reply(content=message.content(), ttl="")
>>> return message.create_reply(content=message.content())
"""

def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
Expand Down Expand Up @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError:
>>> print("ClientApp {fn_name} running")
>>> # Create and return an echo reply message
>>> return message.create_reply(
>>> content=message.content(), ttl=""
>>> content=message.content()
>>> )
""",
)
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Callable, Iterator, Optional, Tuple, Union, cast

from flwr.common import (
DEFAULT_TTL,
GRPC_MAX_MESSAGE_LENGTH,
ConfigsRecord,
Message,
Expand Down Expand Up @@ -180,7 +181,7 @@ def receive() -> Message:
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=message_type,
),
content=recordset,
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import grpc

from flwr.common import ConfigsRecord, Message, Metadata, RecordSet
from flwr.common import DEFAULT_TTL, ConfigsRecord, Message, Metadata, RecordSet
from flwr.common import recordset_compat as compat
from flwr.common.constant import MessageTypeLegacy
from flwr.common.retry_invoker import RetryInvoker, exponential
Expand All @@ -50,7 +50,7 @@
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=compat.getpropertiesres_to_recordset(
Expand All @@ -65,7 +65,7 @@
dst_node_id=0,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type="reconnect",
),
content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}),
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
out_message = message.create_reply(recordset, ttl="")
out_message = message.create_reply(recordset)
# Return TaskRes and sleep duration
return out_message, sleep_duration

Expand Down Expand Up @@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype(
raise ValueError(f"Invalid message type: {message_type}")

# Return Message
return message.create_reply(out_recordset, ttl="")
return message.create_reply(out_recordset)


def _reconnect(
Expand Down
13 changes: 7 additions & 6 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.common import (
DEFAULT_TTL,
Code,
Context,
EvaluateIns,
Expand Down Expand Up @@ -131,7 +132,7 @@ def test_client_without_get_properties() -> None:
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=recordset,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_client_without_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand All @@ -184,7 +185,7 @@ def test_client_with_get_properties() -> None:
src_node_id=0,
dst_node_id=1123,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=recordset,
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_client_with_get_properties() -> None:
src_node_id=1123,
dst_node_id=0,
reply_to_message=message.metadata.message_id,
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
content=expected_rs,
Expand All @@ -237,7 +238,7 @@ def setUp(self) -> None:
dst_node_id=20,
reply_to_message="",
group_id="group1",
ttl="60",
ttl=DEFAULT_TTL,
message_type="mock",
)
self.valid_out_metadata = Metadata(
Expand All @@ -247,7 +248,7 @@ def setUp(self) -> None:
dst_node_id=10,
reply_to_message="qwerty",
group_id="group1",
ttl="60",
ttl=DEFAULT_TTL,
message_type="mock",
)
self.common_content = RecordSet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def secaggplus_mod(

# Return message
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
return msg.create_reply(out_content, ttl="")
return msg.create_reply(out_content)


def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
from typing import Callable, Dict, List

from flwr.client.mod import make_ffn
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet
from flwr.common import (
DEFAULT_TTL,
ConfigsRecord,
Context,
Message,
Metadata,
RecordSet,
)
from flwr.common.constant import MessageType
from flwr.common.secure_aggregation.secaggplus_constants import (
RECORD_KEY_CONFIGS,
Expand All @@ -38,7 +45,7 @@ def get_test_handler(
"""."""

def empty_ffn(_msg: Message, _2: Context) -> Message:
return _msg.create_reply(RecordSet(), ttl="")
return _msg.create_reply(RecordSet())

app = make_ffn(empty_ffn, [secaggplus_mod])

Expand All @@ -51,7 +58,7 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:
dst_node_id=123,
reply_to_message="",
group_id="",
ttl="",
ttl=DEFAULT_TTL,
message_type=MessageType.TRAIN,
),
content=RecordSet(
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flwr.client.typing import ClientAppCallable, Mod
from flwr.common import (
DEFAULT_TTL,
ConfigsRecord,
Context,
Message,
Expand Down Expand Up @@ -84,7 +85,7 @@ def _get_dummy_flower_message() -> Message:
src_node_id=0,
dst_node_id=0,
reply_to_message="",
ttl="",
ttl=DEFAULT_TTL,
message_type="mock",
),
)
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .grpc import GRPC_MAX_MESSAGE_LENGTH
from .logger import configure as configure
from .logger import log as log
from .message import DEFAULT_TTL
from .message import Error as Error
from .message import Message as Message
from .message import Metadata as Metadata
Expand Down Expand Up @@ -87,6 +88,7 @@
"Message",
"MessageType",
"MessageTypeLegacy",
"DEFAULT_TTL",
"Metadata",
"Metrics",
"MetricsAggregationFn",
Expand Down
Loading