Skip to content

Commit

Permalink
Rename middleware layers to mods (#2911)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Feb 9, 2024
1 parent 57da58a commit e409f63
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 138 deletions.
1 change: 1 addition & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def find_test_modules(package_path):
"writing-documentation": "contributor-how-to-write-documentation.html",
"apiref-binaries": "ref-api-cli.html",
"fedbn-example-pytorch-from-centralized-to-federated": "example-fedbn-pytorch-from-centralized-to-federated.html",
"how-to-use-built-in-middleware-layers": "how-to-use-built-in-mods.html",
# Restructuring: tutorials
"tutorial/Flower-0-What-is-FL": "tutorial-series-what-is-federated-learning.html",
"tutorial/Flower-1-Intro-to-FL-PyTorch": "tutorial-series-get-started-with-flower-pytorch.html",
Expand Down
87 changes: 0 additions & 87 deletions doc/source/how-to-use-built-in-middleware-layers.rst

This file was deleted.

89 changes: 89 additions & 0 deletions doc/source/how-to-use-built-in-mods.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
Use Built-in Mods
=================

**Note: This tutorial covers experimental features. The functionality and interfaces may change in future versions.**

In this tutorial, we will learn how to utilize built-in mods to augment the behavior of a ``FlowerCallable``. Mods (sometimes also called Modifiers) allow us to perform operations before and after a task is processed in the ``FlowerCallable``.

What are Mods?
--------------

A Mod is a callable that wraps around a ``FlowerCallable``. It can manipulate or inspect the incoming ``Message`` and the resulting outgoing ``Message``. The signature for a ``Mod`` is as follows:

.. code-block:: python
FlowerCallable = Callable[[Fwd], Bwd]
Mod = Callable[[Fwd, FlowerCallable], Bwd]
A typical mod function might look something like this:

.. code-block:: python
def example_mod(msg: Message, ctx: Context, nxt: FlowerCallable) -> Message:
# Do something with incoming Message (or Context)
# before passing to the inner ``FlowerCallable``
msg = nxt(msg, ctx)
# Do something with outgoing Message (or Context)
# before returning
return msg
Using Mods
----------

To use mods in your ``FlowerCallable``, you can follow these steps:

1. Import the required mods
~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, import the built-in mod you intend to use:

.. code-block:: python
import flwr as fl
from flwr.client.mod import example_mod_1, example_mod_2
2. Define your client function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Define your client function (``client_fn``) that will be wrapped by the mod(s):

.. code-block:: python
def client_fn(cid):
# Your client code goes here.
return # your client
3. Create the ``FlowerCallable`` with mods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Create your ``FlowerCallable`` and pass the mods as a list to the ``mods`` argument. The order in which you provide the mods matters:

.. code-block:: python
flower = fl.app.Flower(
client_fn=client_fn,
mods=[
example_mod_1, # Mod 1
example_mod_2, # Mod 2
]
)
Order of execution
------------------

When the ``FlowerCallable`` runs, the mods are executed in the order they are provided in the list:

1. ``example_mod_1`` (outermost mod)
2. ``example_mod_2`` (next mod)
3. Message handler (core function that handles the incoming ``Message`` and returns the outgoing ``Message``)
4. ``example_mod_2`` (on the way back)
5. ``example_mod_1`` (outermost mod on the way back)

Each mod has a chance to inspect and modify the incoming ``Message`` before passing it to the next mod, and likewise with the outgoing ``Message`` before returning it up the stack.

Conclusion
----------

By following this guide, you have learned how to effectively use mods to enhance your ``FlowerCallable``'s functionality. Remember that the order of mods is crucial and affects how the input and output are processed.

Enjoy building more robust and flexible ``FlowerCallable``s with mods!
2 changes: 1 addition & 1 deletion doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal.
how-to-configure-logging
how-to-enable-ssl-connections
how-to-upgrade-to-flower-1.0
how-to-use-built-in-middleware-layers
how-to-use-built-in-mods
how-to-run-flower-using-docker

.. toctree::
Expand Down
4 changes: 2 additions & 2 deletions examples/secaggplus-mt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import flwr as fl
from flwr.common import Status, FitIns, FitRes, Code
from flwr.common.parameter import ndarrays_to_parameters
from flwr.client.middleware import secaggplus_middleware
from flwr.client.mod import secaggplus_mod


# Define Flower client with the SecAgg+ protocol
Expand Down Expand Up @@ -35,7 +35,7 @@ def client_fn(cid: str):
# To run this: `flower-client --callable client:flower`
flower = fl.flower.Flower(
client_fn=client_fn,
layers=[secaggplus_middleware],
mods=[secaggplus_mod],
)


Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from flwr.client.message_handler.message_handler import (
handle_legacy_message_from_tasktype,
)
from flwr.client.middleware.utils import make_ffn
from flwr.client.typing import ClientFn, Layer
from flwr.client.mod.utils import make_ffn
from flwr.client.typing import ClientFn, Mod
from flwr.common.context import Context
from flwr.common.message import Message

Expand Down Expand Up @@ -56,7 +56,7 @@ class Flower:
def __init__(
self,
client_fn: ClientFn, # Only for backward compatibility
layers: Optional[List[Layer]] = None,
mods: Optional[List[Mod]] = None,
) -> None:
# Create wrapper function for `handle`
def ffn(
Expand All @@ -68,8 +68,8 @@ def ffn(
)
return out_message

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])
# Wrap mods around the wrapped handle function
self._call = make_ffn(ffn, mods if mods is not None else [])

def __call__(self, message: Message, context: Context) -> Message:
"""."""
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
def handle_legacy_message_from_tasktype(
client_fn: ClientFn, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most middleware layer."""
"""Handle legacy message in the inner most mod."""
client = client_fn("-1")

client.set_context(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Middleware layers."""
"""Mods."""


from .secure_aggregation.secaggplus_middleware import secaggplus_middleware
from .secure_aggregation.secaggplus_mod import secaggplus_mod
from .utils import make_ffn

__all__ = [
"make_ffn",
"secaggplus_middleware",
"secaggplus_mod",
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Secure Aggregation handlers."""
from .secaggplus_middleware import secaggplus_middleware
"""Secure Aggregation mods."""


from .secaggplus_mod import secaggplus_mod

__all__ = [
"secaggplus_middleware",
"secaggplus_mod",
]
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def fit() -> FitRes:
return fit


def secaggplus_middleware(
def secaggplus_mod(
msg: Message,
ctxt: Context,
call_next: FlowerCallable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from itertools import product
from typing import Callable, Dict, List

from flwr.client.middleware import make_ffn
from flwr.client.mod import make_ffn
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.context import Context
Expand Down Expand Up @@ -47,7 +47,7 @@
)
from flwr.common.typing import ConfigsRecordValues

from .secaggplus_middleware import SecAggPlusState, check_configs, secaggplus_middleware
from .secaggplus_mod import SecAggPlusState, check_configs, secaggplus_mod


def get_test_handler(
Expand All @@ -61,7 +61,7 @@ def empty_ffn(_: Message, _2: Context) -> Message:
message=RecordSet(),
)

app = make_ffn(empty_ffn, [secaggplus_middleware])
app = make_ffn(empty_ffn, [secaggplus_mod])

def func(configs: Dict[str, ConfigsRecordValues]) -> Dict[str, ConfigsRecordValues]:
in_msg = Message(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for middleware layers."""
"""Utility functions for mods."""


from typing import List

from flwr.client.typing import FlowerCallable, Layer
from flwr.client.typing import FlowerCallable, Mod
from flwr.common.context import Context
from flwr.common.message import Message


def make_ffn(ffn: FlowerCallable, layers: List[Layer]) -> FlowerCallable:
def make_ffn(ffn: FlowerCallable, mods: List[Mod]) -> FlowerCallable:
"""."""

def wrap_ffn(_ffn: FlowerCallable, _layer: Layer) -> FlowerCallable:
def wrap_ffn(_ffn: FlowerCallable, _mod: Mod) -> FlowerCallable:
def new_ffn(message: Message, context: Context) -> Message:
return _layer(message, context, _ffn)
return _mod(message, context, _ffn)

return new_ffn

for layer in reversed(layers):
ffn = wrap_ffn(ffn, layer)
for mod in reversed(mods):
ffn = wrap_ffn(ffn, mod)

return ffn
Loading

0 comments on commit e409f63

Please sign in to comment.