Skip to content

Commit

Permalink
Change workload_id type to uint64 (#2413)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Oct 7, 2023
1 parent e1ba808 commit e7977b5
Show file tree
Hide file tree
Showing 21 changed files with 56 additions and 58 deletions.
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ service Driver {

// CreateWorkload
message CreateWorkloadRequest {}
message CreateWorkloadResponse { string workload_id = 1; }
message CreateWorkloadResponse { uint64 workload_id = 1; }

// GetNodes messages
message GetNodesRequest { string workload_id = 1; }
message GetNodesRequest { uint64 workload_id = 1; }
message GetNodesResponse { repeated Node nodes = 1; }

// PushTaskIns messages
Expand Down
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ message Task {
message TaskIns {
string task_id = 1;
string group_id = 2;
string workload_id = 3;
uint64 workload_id = 3;
Task task = 4;
}

message TaskRes {
string task_id = 1;
string group_id = 2;
string workload_id = 3;
uint64 workload_id = 3;
Task task = 4;
}

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def receive() -> TaskIns:
return TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
workload_id="",
workload_id=0,
task=Task(
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
task_res = TaskRes(
task_id="",
group_id="",
workload_id="",
workload_id=0,
task=Task(
ancestry=[],
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_client_without_get_properties() -> None:
task_ins: TaskIns = TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
workload_id="",
workload_id=0,
task=Task(
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
Expand All @@ -146,7 +146,7 @@ def test_client_without_get_properties() -> None:
TaskRes(
task_id=str(uuid.uuid4()),
group_id="",
workload_id="",
workload_id=0,
)
)
# pylint: disable=no-member
Expand Down Expand Up @@ -183,7 +183,7 @@ def test_client_with_get_properties() -> None:
task_ins = TaskIns(
task_id=str(uuid.uuid4()),
group_id="",
workload_id="",
workload_id=0,
task=Task(
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=0, anonymous=True),
Expand All @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None:
TaskRes(
task_id=str(uuid.uuid4()),
group_id="",
workload_id="",
workload_id=0,
)
)
# pylint: disable=no-member
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes:
return TaskRes(
task_id="",
group_id="",
workload_id="",
workload_id=0,
task=Task(ancestry=[], legacy_client_message=client_message),
)

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/task_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_validate_task_res() -> None:
assert not validate_task_res(task_res)

task_res.Clear()
task_res.workload_id = "123"
task_res.workload_id = 61016
assert not validate_task_res(task_res)

task_res.Clear()
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/driver/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_simple_client_manager_update(self) -> None:
]
driver = MagicMock()
driver.stub = "driver stub"
driver.create_workload.return_value = CreateWorkloadResponse(workload_id="1")
driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1)
driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes)
client_manager = SimpleClientManager()
lock = threading.Lock()
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""

def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: str):
def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int):
super().__init__(str(node_id))
self.node_id = node_id
self.driver = driver
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_get_properties(self) -> None:
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_properties_res=ClientMessage.GetPropertiesRes(
Expand All @@ -64,7 +64,7 @@ def test_get_properties(self) -> None:
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
request_properties: Config = {"tensor_type": "str"}
ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns(
Expand All @@ -88,7 +88,7 @@ def test_get_parameters(self) -> None:
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
get_parameters_res=ClientMessage.GetParametersRes(
Expand All @@ -100,7 +100,7 @@ def test_get_parameters(self) -> None:
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
get_parameters_ins = GetParametersIns(config={})

Expand All @@ -123,7 +123,7 @@ def test_fit(self) -> None:
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
fit_res=ClientMessage.FitRes(
Expand All @@ -136,7 +136,7 @@ def test_fit(self) -> None:
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {})
Expand All @@ -160,7 +160,7 @@ def test_evaluate(self) -> None:
task_pb2.TaskRes(
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
workload_id="",
workload_id=0,
task=task_pb2.Task(
legacy_client_message=ClientMessage(
evaluate_res=ClientMessage.EvaluateRes(
Expand All @@ -172,7 +172,7 @@ def test_evaluate(self) -> None:
]
)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
node_id=1, driver=self.driver, anonymous=True, workload_id=0
)
parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {})
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/py/flwr/proto/driver_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ global___CreateWorkloadRequest = CreateWorkloadRequest
class CreateWorkloadResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
workload_id: builtins.int
def __init__(self,
*,
workload_id: typing.Text = ...,
workload_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___CreateWorkloadResponse = CreateWorkloadResponse
Expand All @@ -35,10 +35,10 @@ class GetNodesRequest(google.protobuf.message.Message):
"""GetNodes messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
workload_id: builtins.int
def __init__(self,
*,
workload_id: typing.Text = ...,
workload_id: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___GetNodesRequest = GetNodesRequest
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/proto/task_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/py/flwr/proto/task_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class TaskIns(google.protobuf.message.Message):
TASK_FIELD_NUMBER: builtins.int
task_id: typing.Text
group_id: typing.Text
workload_id: typing.Text
workload_id: builtins.int
@property
def task(self) -> global___Task: ...
def __init__(self,
*,
task_id: typing.Text = ...,
group_id: typing.Text = ...,
workload_id: typing.Text = ...,
workload_id: builtins.int = ...,
task: typing.Optional[global___Task] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ...
Expand All @@ -85,14 +85,14 @@ class TaskRes(google.protobuf.message.Message):
TASK_FIELD_NUMBER: builtins.int
task_id: typing.Text
group_id: typing.Text
workload_id: typing.Text
workload_id: builtins.int
@property
def task(self) -> global___Task: ...
def __init__(self,
*,
task_id: typing.Text = ...,
group_id: typing.Text = ...,
workload_id: typing.Text = ...,
workload_id: builtins.int = ...,
task: typing.Optional[global___Task] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_push_task_res() -> None:
TaskRes(
task_id="",
group_id="",
workload_id="",
workload_id=0,
task=Task(),
),
],
Expand Down
13 changes: 6 additions & 7 deletions src/py/flwr/server/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class InMemoryState(State):

def __init__(self) -> None:
self.node_ids: Set[int] = set()
self.workload_ids: Set[str] = set()
self.workload_ids: Set[int] = set()
self.task_ins_store: Dict[UUID, TaskIns] = {}
self.task_res_store: Dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -194,7 +194,7 @@ def unregister_node(self, node_id: int) -> None:
raise ValueError(f"Node {node_id} is not registered")
self.node_ids.remove(node_id)

def get_nodes(self, workload_id: str) -> Set[int]:
def get_nodes(self, workload_id: int) -> Set[int]:
"""Return all available client nodes.
Constraints
Expand All @@ -206,14 +206,13 @@ def get_nodes(self, workload_id: str) -> Set[int]:
return set()
return self.node_ids

def create_workload(self) -> str:
def create_workload(self) -> int:
"""Create one workload."""
# String representation of random integer from 0 to 9223372036854775807
random_workload_id: int = random.randrange(9223372036854775808)
workload_id = str(random_workload_id)
# Sample random integer from 0 to 9223372036854775807
workload_id: int = random.randrange(9223372036854775808)

if workload_id not in self.workload_ids:
self.workload_ids.add(workload_id)
return workload_id
log(ERROR, "Unexpected workload creation failure.")
return ""
return 0
Loading

0 comments on commit e7977b5

Please sign in to comment.