diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 1cfb77135d5a..1caaad88a0da 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -36,10 +36,10 @@ service Driver { // CreateWorkload message CreateWorkloadRequest {} -message CreateWorkloadResponse { string workload_id = 1; } +message CreateWorkloadResponse { uint64 workload_id = 1; } // GetNodes messages -message GetNodesRequest { string workload_id = 1; } +message GetNodesRequest { uint64 workload_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 29e07641bb1c..d87fb39c2637 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -36,14 +36,14 @@ message Task { message TaskIns { string task_id = 1; string group_id = 2; - string workload_id = 3; + uint64 workload_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - string workload_id = 3; + uint64 workload_id = 3; Task task = 4; } diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index eda869d3a326..cc64ec9a268a 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -117,7 +117,7 @@ def receive() -> TaskIns: return TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 9b26a9bd5ca0..f50923450f62 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -76,7 +76,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]: task_res = TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( ancestry=[], sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 31cbb00edf63..1fc2269ad75d 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -120,7 +120,7 @@ def test_client_without_get_properties() -> None: task_ins: TaskIns = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -146,7 +146,7 @@ def test_client_without_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, ) ) # pylint: disable=no-member @@ -183,7 +183,7 @@ def test_client_with_get_properties() -> None: task_ins = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, ) ) # pylint: disable=no-member diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 03688c52ac8f..b48c7433c1da 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -129,7 +129,7 @@ def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: return TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task(ancestry=[], legacy_client_message=client_message), ) diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 347b9ad32c4b..e1b7fac69d24 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -92,7 +92,7 @@ def test_validate_task_res() -> None: assert not validate_task_res(task_res) task_res.Clear() - task_res.workload_id = "123" + task_res.workload_id = 61016 assert not validate_task_res(task_res) task_res.Clear() diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 792bd84b6106..4fcd924f8432 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -43,7 +43,7 @@ def test_simple_client_manager_update(self) -> None: ] driver = MagicMock() driver.stub = "driver stub" - driver.create_workload.return_value = CreateWorkloadResponse(workload_id="1") + driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1) driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) client_manager = SimpleClientManager() lock = threading.Lock() diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index cd5d36cafdd7..deb472458a15 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -31,7 +31,7 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: str): + def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index fa2a29e88687..f413b8d8d99d 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -52,7 +52,7 @@ def test_get_properties(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_properties_res=ClientMessage.GetPropertiesRes( @@ -64,7 +64,7 @@ def test_get_properties(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) request_properties: Config = {"tensor_type": "str"} ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( @@ -88,7 +88,7 @@ def test_get_parameters(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_parameters_res=ClientMessage.GetParametersRes( @@ -100,7 +100,7 @@ def test_get_parameters(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) get_parameters_ins = GetParametersIns(config={}) @@ -123,7 +123,7 @@ def test_fit(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( fit_res=ClientMessage.FitRes( @@ -136,7 +136,7 @@ def test_fit(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) @@ -160,7 +160,7 @@ def test_evaluate(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( evaluate_res=ClientMessage.EvaluateRes( @@ -172,7 +172,7 @@ def test_evaluate(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) parameters = flwr.common.Parameters(tensors=[], tensor_type="np") evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index c18d9c593c28..6ac066d7eab3 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\t\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\t\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x04\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 486bddb0f76f..8b940972cb6d 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -23,10 +23,10 @@ global___CreateWorkloadRequest = CreateWorkloadRequest class CreateWorkloadResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: typing.Text + workload_id: builtins.int def __init__(self, *, - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... global___CreateWorkloadResponse = CreateWorkloadResponse @@ -35,10 +35,10 @@ class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: typing.Text + workload_id: builtins.int def __init__(self, *, - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... global___GetNodesRequest = GetNodesRequest diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 42d3952f61df..69bad48d0d37 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\t\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\t\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index dcd4686944bc..7cf96cb61edf 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -63,14 +63,14 @@ class TaskIns(google.protobuf.message.Message): TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: typing.Text + workload_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... @@ -85,14 +85,14 @@ class TaskRes(google.protobuf.message.Message): TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: typing.Text + workload_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/fleet/message_handler/message_handler_test.py index 10f678e3479e..da92b267f082 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler_test.py @@ -109,7 +109,7 @@ def test_push_task_res() -> None: TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task(), ), ], diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/state/in_memory_state.py index 075ba2cf304d..d6292571cd6d 100644 --- a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/state/in_memory_state.py @@ -32,7 +32,7 @@ class InMemoryState(State): def __init__(self) -> None: self.node_ids: Set[int] = set() - self.workload_ids: Set[str] = set() + self.workload_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} @@ -194,7 +194,7 @@ def unregister_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} is not registered") self.node_ids.remove(node_id) - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Return all available client nodes. Constraints @@ -206,14 +206,13 @@ def get_nodes(self, workload_id: str) -> Set[int]: return set() return self.node_ids - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload.""" - # String representation of random integer from 0 to 9223372036854775807 - random_workload_id: int = random.randrange(9223372036854775808) - workload_id = str(random_workload_id) + # Sample random integer from 0 to 9223372036854775807 + workload_id: int = random.randrange(9223372036854775808) if workload_id not in self.workload_ids: self.workload_ids.add(workload_id) return workload_id log(ERROR, "Unexpected workload creation failure.") - return "" + return 0 diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index e971c11da2f5..0c853409b844 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -39,7 +39,7 @@ SQL_CREATE_TABLE_WORKLOAD = """ CREATE TABLE IF NOT EXISTS workload( - workload_id TEXT UNIQUE + workload_id INTEGER UNIQUE ); """ @@ -47,7 +47,7 @@ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, group_id TEXT, - workload_id TEXT, + workload_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS task_res( task_id TEXT UNIQUE, group_id TEXT, - workload_id TEXT, + workload_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -479,7 +479,7 @@ def unregister_node(self, node_id: int) -> None: query = "DELETE FROM node WHERE node_id = :node_id;" self.query(query, {"node_id": node_id}) - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints @@ -498,11 +498,10 @@ def get_nodes(self, workload_id: str) -> Set[int]: result: Set[int] = {row["node_id"] for row in rows} return result - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload and store it in state.""" - # String representation of random integer from 0 to 9223372036854775807 - random_workload_id: int = random.randrange(9223372036854775808) - workload_id = str(random_workload_id) + # Sample random integer from 0 to 9223372036854775807 + workload_id: int = random.randrange(9223372036854775808) # Check conflicts query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" @@ -512,7 +511,7 @@ def create_workload(self) -> str: self.query(query, {"workload_id": workload_id}) return workload_id log(ERROR, "Unexpected workload creation failure.") - return "" + return 0 def dict_factory( diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py index e3bb72e34118..b9c0df9ed134 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/state/sqlite_state_test.py @@ -27,7 +27,7 @@ class SqliteStateTest(unittest.TestCase): def test_ins_res_to_dict(self) -> None: """Check if all required keys are included in return value.""" # Prepare - ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id="") + ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id=0) expected_keys = [ "task_id", "group_id", diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/state/state.py index cfd68c589b6e..a0b9e663f637 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/state/state.py @@ -140,7 +140,7 @@ def unregister_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints @@ -150,5 +150,5 @@ def get_nodes(self, workload_id: str) -> Set[int]: """ @abc.abstractmethod - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload.""" diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index e80bd55352ed..bc3015ba5cc2 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -283,7 +283,7 @@ def test_task_ins_store_invalid_workload_id_and_fail(self) -> None: # Prepare state: State = self.state_factory() task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id="I'm invalid" + consumer_node_id=0, anonymous=True, workload_id=61016 ) # Execute @@ -362,7 +362,7 @@ def test_get_nodes_invalid_workload_id(self) -> None: # Prepare state: State = self.state_factory() state.create_workload() - invalid_workload_id = "" + invalid_workload_id = 61016 node_id = 2 # Execute @@ -420,7 +420,7 @@ def test_num_task_res(self) -> None: def create_task_ins( consumer_node_id: int, anonymous: bool, - workload_id: str, + workload_id: int, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -448,7 +448,7 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - workload_id: str, + workload_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 533e3a236572..54840731048f 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -135,7 +135,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -162,7 +162,7 @@ def create_task_res( task_res = TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True),