docs(framework) Update Quickstart Tutorial documentation for JAX with flwr run #3367

Merged Oct 16, 2024 · 16 commits · Changes from 7 commits
170 changes: 115 additions & 55 deletions in doc/source/tutorial-quickstart-jax.rst
Quickstart JAX
==============
.. meta::
    :description: Check out this Federated Learning quickstart tutorial for using Flower with JAX to train a linear regression model on a scikit-learn dataset.

Let's build a federated learning system using JAX and the Flower framework!

We will leverage JAX to train a linear regression model on a scikit-learn dataset.
We will structure the example similarly to our `PyTorch - From Centralized To Federated <https://github.com/adap/flower/blob/main/examples/pytorch-from-centralized-to-federated>`_ walkthrough.
First, we build a centralized training approach based on the `Linear Regression with JAX <https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html>`_ tutorial.
Then, we build upon the centralized training code to run the training in a federated fashion over multiple clients using Flower.


Dependencies
------------

First of all, it is recommended to create a virtual environment and run everything within a :doc:`virtualenv <contributor-how-to-set-up-a-virtual-env>`.

To follow along with this tutorial, you will need to install the following packages: :code:`jax`, :code:`jaxlib`, :code:`scikit-learn`, and :code:`flwr`. This can be done using :code:`pip`:

.. code-block:: shell

    $ pip install flwr jax jaxlib scikit-learn
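
If you want to double-check that the packages are importable before moving on, a quick sanity check from Python works. This is a minimal sketch; the :code:`__version__` attributes are an assumption that holds for recent releases of these packages:

.. code-block:: python

    # Sanity check: import the installed packages and print their versions.
    import flwr
    import jax
    import sklearn

    print("flwr", flwr.__version__)
    print("jax", jax.__version__)
    print("scikit-learn", sklearn.__version__)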


Linear Regression with JAX
--------------------------

We now need to define the training function :code:`train()`, which loops over the training set and, in each step, computes the gradients and updates the model parameters (the diff view collapses parts of this function, so the signature and loop below are reconstructed):
.. code-block:: python

    def train(params, grad_fn, X, y):
        num_examples = X.shape[0]
        for epochs in range(50):  # epoch count not visible in this diff; 50 is illustrative
            grads = grad_fn(params, X, y)
            params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
            loss = loss_fn(params, X, y)
        return params, loss, num_examples
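
The :code:`train()` function relies on :code:`loss_fn` and the data loading code, which this diff view collapses. For orientation only, here is a hypothetical squared-error loss for a linear model that is consistent with how :code:`loss_fn` is called above; the parameter names :code:`w` and :code:`b` are assumptions, not the tutorial's exact code:

.. code-block:: python

    import jax.numpy as jnp

    def loss_fn(params, X, y):
        # Mean squared error of the linear model X @ w + b.
        pred = jnp.dot(X, params["w"]) + params["b"]
        return jnp.mean(jnp.square(pred - y))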

The evaluation of the model is defined in the function :code:`evaluation()`. The function takes all test examples and measures the loss of the linear regression model (as with :code:`train()`, the signature below is reconstructed from the visible diff lines):

.. code-block:: python

    def evaluation(params, grad_fn, X_test, y_test):
        num_examples = X_test.shape[0]
        err_test = loss_fn(params, X_test, y_test)
        loss_test = jnp.mean(jnp.square(err_test))
        return loss_test, num_examples

Having defined the data loading, model architecture, training, and evaluation, we can put everything together and train our model using JAX. As already mentioned, the :code:`jax.grad()` function is defined in :code:`main()` and passed to :code:`train()`.
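
The body of :code:`main()` is collapsed in this diff view. Assuming :code:`load_data()` and :code:`load_model()` are defined as in the collapsed sections, a sketch of how the pieces plausibly fit together looks like this (not the tutorial's exact code):

.. code-block:: python

    def main():
        # Load the scikit-learn regression data and derive the model shape.
        train_x, train_y, test_x, test_y = load_data()
        model_shape = train_x.shape[1:]

        # jax.grad() differentiates the loss with respect to the parameters.
        grad_fn = jax.grad(loss_fn)

        # Initialize the model, train it, then evaluate on the test split.
        params = load_model(model_shape)
        params, loss, num_examples = train(params, grad_fn, train_x, train_y)
        evaluation(params, grad_fn, test_x, test_y)


    if __name__ == "__main__":
        main()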

You can now run your (centralized) JAX linear regression workload:

.. code-block:: shell

    $ python3 jax_training.py

So far this should all look fairly familiar if you've used JAX before.
Let's take the next step and use what we've built to create a simple federated learning system consisting of one server and two clients.
JAX meets Flower
----------------

The concept of federating an existing workload is always the same and easy to understand.
We have to define the Flower interface for the *clients* using the code in :code:`jax_training.py`. We also start a *server* for the *clients* to connect to.
The *server* sends model parameters to the clients. The *clients* run the training and update the parameters.
The updated parameters are sent back to the *server*, which averages all received parameter updates.
This describes one round of the federated learning process, and we repeat this for multiple rounds.
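
To make the averaging step concrete, here is an illustrative weighted average of client updates in plain NumPy. This is a hypothetical helper, not Flower's internals; Flower's built-in :code:`FedAvg` strategy performs this kind of example-count-weighted aggregation for us:

.. code-block:: python

    import numpy as np

    def weighted_average(results):
        """Average client updates, weighted by their number of examples.

        `results` is a list of (parameters, num_examples) pairs, where each
        `parameters` is a list of NumPy arrays (one entry per model layer).
        """
        total = sum(num_examples for _, num_examples in results)
        num_layers = len(results[0][0])
        return [
            sum(params[layer] * (num_examples / total)
                for params, num_examples in results)
            for layer in range(num_layers)
        ]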

Our example consists of one *server* and two *clients*.

Flower Client
^^^^^^^^^^^^^

First, we set up our *client* logic in :code:`client.py` by building upon the previously defined JAX training in :code:`jax_training.py`.
Our *client* needs to import :code:`flwr` and :code:`jax` to update the parameters on our JAX model:


.. code-block:: python

    from typing import Dict, List, Tuple

    import flwr as fl
    import jax
    import jax.numpy as jnp
    import numpy as np
    from flwr.client import ClientApp

    import jax_training

Now we load the data, build the gradient function from the loss, and set the model shape:

.. code-block:: python

    train_x, train_y, test_x, test_y = jax_training.load_data()
    grad_fn = jax.grad(jax_training.loss_fn)
    model_shape = train_x.shape[1:]

After preparing the data and model, we define the Flower interface.

Implementing a Flower *client* means implementing a subclass of either :code:`flwr.client.Client` or :code:`flwr.client.NumPyClient`.
Our implementation will be based on :code:`flwr.client.NumPyClient` and we'll call it :code:`FlowerClient`.
We included type annotations to give you a better understanding of the data types.

.. code-block:: python

    class FlowerClient(fl.client.NumPyClient):
        """Flower client implementing linear regression using JAX."""

        # ... (the diff view collapses the constructor, get_parameters,
        # and the start of set_parameters; below is its visible tail)
                value = item[1]
                self.params[key] = value
            return self.params

        def fit(
            self, parameters: List[np.ndarray], config: Dict
        # ... (the diff view collapses the remainder of the class; the
        # final visible lines close a (loss, num_examples, metrics) return)
                {"loss": float(loss)},
            )
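
Because the diff view collapses most of the class body, here is a minimal skeleton of the :code:`NumPyClient` interface the tutorial implements. The method names and signatures match Flower's :code:`NumPyClient`; the bodies are placeholders, not the tutorial's code:

.. code-block:: python

    class NumPyClientSkeleton(fl.client.NumPyClient):
        """Placeholder sketch of the NumPyClient interface."""

        def get_parameters(self, config: Dict) -> List[np.ndarray]:
            ...  # return the current model parameters as NumPy arrays

        def fit(
            self, parameters: List[np.ndarray], config: Dict
        ) -> Tuple[List[np.ndarray], int, Dict]:
            ...  # set parameters, train locally, return updated parameters

        def evaluate(
            self, parameters: List[np.ndarray], config: Dict
        ) -> Tuple[float, int, Dict]:
            ...  # evaluate the received parameters on local test data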

Next, we create a client function that returns instances of :code:`FlowerClient` on demand when called:

.. code-block:: python

    def client_fn(cid: str):
        return FlowerClient().to_client()
Finally, we create a :code:`ClientApp()` object that uses this client function:

.. code-block:: python

    app = ClientApp(client_fn=client_fn)

That's it for the client. We only have to implement :code:`Client` or :code:`NumPyClient`, create a :code:`ClientApp`, and pass the client function to it. If we implement a client of type :code:`NumPyClient`, we'll need to first call its :code:`to_client()` method.


Flower Server
^^^^^^^^^^^^^

For simple workloads, we create a :code:`ServerApp` and leave all the
configuration possibilities at their default values. In a file named
:code:`server.py`, import Flower and create a :code:`ServerApp`:

.. code-block:: python

    from flwr.server import ServerApp

    app = ServerApp()
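
For later experiments, :code:`ServerApp` also accepts a server configuration and a strategy. As a hedged sketch (keyword names as in recent :code:`flwr` releases; check your installed version):

.. code-block:: python

    from flwr.server import ServerApp, ServerConfig
    from flwr.server.strategy import FedAvg

    # Run three federated rounds with the (default) FedAvg strategy.
    app = ServerApp(
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )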


Train the model, federated!
---------------------------

With both the :code:`ClientApp` and :code:`ServerApp` ready, we can now run everything and see federated
learning in action. First, we run the :code:`flower-superlink` command in one terminal to start the infrastructure. This step only needs to be run once.

.. admonition:: Note
    :class: note

    In this example, the :code:`--insecure` command line argument starts Flower without HTTPS and is only used for prototyping. To run with HTTPS, we instead use the argument :code:`--certificates` and pass the paths to the certificates. Please refer to the `Flower CLI reference <ref-api-cli.html>`_ for implementation details.

.. code-block:: shell

    $ flower-superlink --insecure

FL systems usually have a server and multiple clients. We therefore need to start multiple `SuperNodes`, one for each client. First, we open a new terminal and start the first `SuperNode` using the :code:`flower-client-app` command.

.. code-block:: shell

    $ flower-client-app client:app --insecure

In the above, we launch the :code:`app` object in the :code:`client.py` module.
Open another terminal and start the second `SuperNode`:

.. code-block:: shell

    $ flower-client-app client:app --insecure

Finally, in another terminal window, we run the `ServerApp`. This starts the actual training run:

.. code-block:: shell

    $ flower-server-app server:app --insecure

We should now see the training progress in the last terminal (the one that started the :code:`ServerApp`):

.. code-block:: shell

    WARNING : Option `--insecure` was set. Starting insecure HTTP client connected to 0.0.0.0:9091.
    INFO : Starting Flower ServerApp, config: num_rounds=1, no round_timeout
    INFO :
    INFO : [INIT]
    INFO : Requesting initial parameters from one random client
    INFO : Received initial parameters from one random client
    INFO : Evaluating initial global parameters
    INFO :
    INFO : [ROUND 1]
    INFO : configure_fit: strategy sampled 2 clients (out of 2)
    INFO : aggregate_fit: received 2 results and 0 failures
    WARNING : No fit_metrics_aggregation_fn provided
    INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
    INFO : aggregate_evaluate: received 2 results and 0 failures
    WARNING : No evaluate_metrics_aggregation_fn provided
    INFO :
    INFO : [SUMMARY]
    INFO : Run finished 1 rounds in 7.06s
    INFO : History (loss, distributed):
    INFO : '\tround 1: 0.15034367516636848\n'

Congratulations!
You've successfully built and run your first federated learning system with JAX.
The full source code for this example can be found in |quickstart_jax_link|_.

.. |quickstart_jax_link| replace:: :code:`examples/quickstart-jax`
.. _quickstart_jax_link: https://github.com/adap/flower/blob/main/examples/quickstart-jax

Of course, this is a very basic example, and a lot can be added or modified.
How about using a more sophisticated model or using a different dataset? How about adding more clients?