Skip to content

Commit

Permalink
Prepare Task validation to be used in state tests (#1667)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal authored Feb 15, 2023
1 parent c9b5b53 commit 5e13a56
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 151 deletions.
53 changes: 4 additions & 49 deletions src/py/flwr/server/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
PushTaskInsRequest,
PushTaskInsResponse,
)
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.task_pb2 import TaskRes
from flwr.server.state import State
from flwr.server.utils.validator import validate_task_ins_or_res


class DriverServicer(driver_pb2_grpc.DriverServicer):
Expand Down Expand Up @@ -61,7 +62,8 @@ def PushTaskIns(
# Validate request
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
for task_ins in request.task_ins_list:
_validate_incoming_task_ins(task_ins=task_ins)
validation_errors = validate_task_ins_or_res(task_ins)
_raise_if(bool(validation_errors), ", ".join(validation_errors))

# Store each TaskIns
task_ids: List[Optional[UUID]] = []
Expand Down Expand Up @@ -105,53 +107,6 @@ def on_rpc_done() -> None:
return PullTaskResResponse(task_res_list=task_res_list)


def _validate_incoming_task_ins(task_ins: TaskIns) -> None:
"""Validate incoming TaskIns."""

_raise_if(task_ins.task_id != "", "non-empty `task_id`")
_raise_if(not task_ins.HasField("task"), "`task` does not set field `task`")

task: Task = task_ins.task

# Task producer
_raise_if(not task.HasField("producer"), "`producer` does not set field `producer`")
_raise_if(task.producer.node_id != 0, "`producer.node_id` is not 0")
_raise_if(not task.producer.anonymous, "`producer` is not anonymous")

# Task consumer
_raise_if(not task.HasField("consumer"), "`consumer` does not set field `consumer`")
_raise_if(
task.consumer.anonymous and task.consumer.node_id != 0,
"anonymous consumers MUST NOT set a `node_id`",
)
_raise_if(
not task.consumer.anonymous and task.consumer.node_id == 0,
"non-anonymous consumer MUST provide a `node_id`",
)

# Created/delivered/TTL
_raise_if(task.created_at != "", "`created_at` must be an empty str")
_raise_if(task.delivered_at != "", "`delivered_at` must be an empty str")
_raise_if(task.ttl != "", "`ttl` must be an empty str")

# Legacy ServerMessage/ClientMessage
_raise_if(
task.HasField("legacy_client_message"),
"`legacy_client_message` is not `None`",
)
_raise_if(
not task.HasField("legacy_server_message"),
"`task` does not set field `legacy_server_message`",
)
_raise_if(
not task.legacy_server_message.HasField("msg"),
"`legacy_server_message` does not set field `msg`",
)

# Ancestors
_raise_if(len(task.ancestry) != 0, "`ancestry` is not empty")


def _raise_if(validation_error: bool, detail: str) -> None:
if validation_error:
raise ValueError(f"Malformed PushTaskInsRequest: {detail}")
103 changes: 1 addition & 102 deletions src/py/flwr/server/driver/driver_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
"""DriverServicer tests."""


import uuid

from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns
from flwr.proto.transport_pb2 import ServerMessage
from flwr.server.driver.driver_servicer import _raise_if, _validate_incoming_task_ins

VALUE_ERROR_BASE: str = "Malformed PushTaskInsRequest: "
from flwr.server.driver.driver_servicer import _raise_if

# pylint: disable=broad-except

Expand Down Expand Up @@ -63,97 +56,3 @@ def test_raise_if_true() -> None:
assert str(err) == "Malformed PushTaskInsRequest: test"
except Exception:
assert False


def _create_task_ins(
task_id: str = "", task: bool = True, server_message: bool = True
) -> TaskIns:
return TaskIns(
task_id=task_id,
group_id="",
workload_id="",
task=Task(
producer=Node(node_id=0, anonymous=True),
consumer=Node(node_id=1, anonymous=False),
created_at="",
delivered_at="",
ttl="",
ancestry=[],
legacy_server_message=ServerMessage(fit_ins=ServerMessage.FitIns())
if server_message
else None,
legacy_client_message=None,
)
if task
else None,
)


def test_validate_incoming_task_ins_valid() -> None:
"""Test TaskIns validation."""

# Prepare
task_ins = _create_task_ins()

# Execute
try:
_validate_incoming_task_ins(task_ins=task_ins)

# Assert
assert True
except Exception:
assert False


def test_validate_incoming_task_ins_invalid_task_id_set() -> None:
"""Test TaskIns validation."""

# Prepare
task_ins = _create_task_ins(task_id=str(uuid.uuid4()))

# Execute
try:
_validate_incoming_task_ins(task_ins=task_ins)

# Assert
assert False
except ValueError as err:
assert str(err).startswith(VALUE_ERROR_BASE)
except Exception:
assert False


def test_validate_incoming_task_ins_invalid_no_task() -> None:
"""Test TaskIns validation."""

# Prepare
task_ins = _create_task_ins(task=False)

# Execute
try:
_validate_incoming_task_ins(task_ins=task_ins)

# Assert
assert False
except ValueError as err:
assert str(err).startswith(VALUE_ERROR_BASE)
except Exception:
assert False


def test_validate_incoming_task_ins_invalid_no_server_message() -> None:
"""Test TaskIns validation."""

# Prepare
task_ins = _create_task_ins(server_message=False)

# Execute
try:
_validate_incoming_task_ins(task_ins=task_ins)

# Assert
assert False
except ValueError as err:
assert str(err).startswith(VALUE_ERROR_BASE)
except Exception:
assert False
119 changes: 119 additions & 0 deletions src/py/flwr/server/utils/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2023 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators."""
from typing import List, Union

from flwr.proto.task_pb2 import TaskIns, TaskRes


# pylint: disable-next=too-many-branches
def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str]:
"""Validate a TaskIns or TaskRes."""

validation_errors = []

if tasks_ins_res.task_id != "":
validation_errors.append("non-empty `task_id`")

if not tasks_ins_res.HasField("task"):
validation_errors.append("`task` does not set field `task`")

# Created/delivered/TTL
if tasks_ins_res.task.created_at != "":
validation_errors.append("`created_at` must be an empty str")
if tasks_ins_res.task.delivered_at != "":
validation_errors.append("`delivered_at` must be an empty str")
if tasks_ins_res.task.ttl != "":
validation_errors.append("`ttl` must be an empty str")

# TaskIns specific
if isinstance(tasks_ins_res, TaskIns):
# Task producer
if not tasks_ins_res.task.HasField("producer"):
validation_errors.append("`producer` does not set field `producer`")
if tasks_ins_res.task.producer.node_id != 0:
validation_errors.append("`producer.node_id` is not 0")
if not tasks_ins_res.task.producer.anonymous:
validation_errors.append("`producer` is not anonymous")

# Task consumer
if not tasks_ins_res.task.HasField("consumer"):
validation_errors.append("`consumer` does not set field `consumer`")
if (
tasks_ins_res.task.consumer.anonymous
and tasks_ins_res.task.consumer.node_id != 0
):
validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
if (
not tasks_ins_res.task.consumer.anonymous
and tasks_ins_res.task.consumer.node_id == 0
):
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")

# Legacy ServerMessage
if not tasks_ins_res.task.HasField("legacy_server_message"):
validation_errors.append(
"`task` does not set field `legacy_server_message`"
)
if not tasks_ins_res.task.legacy_server_message.HasField("msg"):
validation_errors.append("`legacy_server_message` does not set field `msg`")

# Ancestors
if len(tasks_ins_res.task.ancestry) != 0:
validation_errors.append("`ancestry` is not empty")

# TaskRes specific
if isinstance(tasks_ins_res, TaskRes):
# Task producer
if not tasks_ins_res.task.HasField("producer"):
validation_errors.append("`producer` does not set field `producer`")
if (
tasks_ins_res.task.producer.anonymous
and tasks_ins_res.task.producer.node_id != 0
):
validation_errors.append("anonymous producers MUST NOT set a `node_id`")
if (
not tasks_ins_res.task.producer.anonymous
and tasks_ins_res.task.producer.node_id == 0
):
validation_errors.append("non-anonymous producer MUST provide a `node_id`")

# Task consumer
if not tasks_ins_res.task.HasField("consumer"):
validation_errors.append("`consumer` does not set field `consumer`")
if (
tasks_ins_res.task.consumer.anonymous
and tasks_ins_res.task.consumer.node_id != 0
):
validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
if (
not tasks_ins_res.task.consumer.anonymous
and tasks_ins_res.task.consumer.node_id == 0
):
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")

# Legacy ClientMessage
if not tasks_ins_res.task.HasField("legacy_client_message"):
validation_errors.append(
"`task` does not set field `legacy_client_message`"
)
if not tasks_ins_res.task.legacy_client_message.HasField("msg"):
validation_errors.append("`legacy_client_message` does not set field `msg`")

# Ancestors
if len(tasks_ins_res.task.ancestry) == 0:
validation_errors.append("`ancestry` is empty")

return validation_errors
Loading

0 comments on commit 5e13a56

Please sign in to comment.