diff --git a/doc/source/tutorial-quickstart-jax.rst b/doc/source/tutorial-quickstart-jax.rst
index 0581e95d8d42..833270d5636f 100644
--- a/doc/source/tutorial-quickstart-jax.rst
+++ b/doc/source/tutorial-quickstart-jax.rst
@@ -3,324 +3,303 @@
Quickstart JAX
==============
-.. meta::
- :description: Check out this Federated Learning quickstart tutorial for using Flower with Jax to train a linear regression model on a scikit-learn dataset.
+In this federated learning tutorial we will learn how to train a linear regression model
+using Flower and `JAX <https://jax.readthedocs.io/en/latest/>`_. It is recommended to
+create a virtual environment and run everything within a :doc:`virtualenv
+<contributor-how-to-set-up-a-virtual-env>`.
-This tutorial will show you how to use Flower to build a federated version of an
-existing JAX workload. We are using JAX to train a linear regression model on a
-scikit-learn dataset. We will structure the example similar to our `PyTorch - From
-Centralized To Federated
-`_
-walkthrough. First, we build a centralized training approach based on the `Linear
-Regression with JAX
-`_ tutorial`.
-Then, we build upon the centralized training code to run the training in a federated
-fashion.
+Let's use ``flwr new`` to create a complete Flower+JAX project. It will generate all the
+files needed to run, by default with the Flower Simulation Engine, a federation of 10
+nodes using |fedavg|_. A random regression dataset will be loaded from scikit-learn's
+|makeregression|_ function.
-Before we start building our JAX example, we need install the packages ``jax``,
-``jaxlib``, ``scikit-learn``, and ``flwr``:
+Now that we have a rough idea of what this example is about, let's get started. First,
+install Flower in your new environment:
.. code-block:: shell
- $ pip install jax jaxlib scikit-learn flwr
+ # In a new Python environment
+ $ pip install flwr
+
+Then, run the command below. You will be prompted to select one of the available
+templates (choose ``JAX``), give a name to your project, and type in your developer
+name:
-Linear Regression with JAX
---------------------------
+.. code-block:: shell
-We begin with a brief description of the centralized training code based on a ``Linear
-Regression`` model. If you want a more in-depth explanation of what's going on then have
-a look at the official `JAX documentation `_.
+ $ flwr new
-Let's create a new file called ``jax_training.py`` with all the components required for
-a traditional (centralized) linear regression training. First, the JAX packages ``jax``
-and ``jaxlib`` need to be imported. In addition, we need to import ``sklearn`` since we
-use ``make_regression`` for the dataset and ``train_test_split`` to split the dataset
-into a training and test set. You can see that we do not yet import the ``flwr`` package
-for federated learning. This will be done later.
+After running it you'll notice a new directory with your project name has been created.
+It should have the following structure:
-.. code-block:: python
+.. code-block:: shell
- from typing import Dict, List, Tuple, Callable
- import jax
- import jax.numpy as jnp
- from sklearn.datasets import make_regression
- from sklearn.model_selection import train_test_split
+    <your-project-name>
+    ├── <your-project-name>
+    │   ├── __init__.py
+    │   ├── client_app.py   # Defines your ClientApp
+    │   ├── server_app.py   # Defines your ServerApp
+    │   └── task.py         # Defines your model, training and data loading
+    ├── pyproject.toml      # Project metadata like dependencies and configs
+    └── README.md
- key = jax.random.PRNGKey(0)
+If you haven't yet installed the project and its dependencies, you can do so by:
+
+.. code-block:: shell
+
+ # From the directory where your pyproject.toml is
+ $ pip install -e .
+
+To run the project, do:
+
+.. code-block:: shell
-The ``load_data()`` function loads the mentioned training and test sets.
+ # Run with default arguments
+ $ flwr run .
+
+With default arguments you will see an output like this one:
+
+.. code-block:: shell
+
+ Loading project configuration...
+ Success
+ INFO : Starting Flower ServerApp, config: num_rounds=3, no round_timeout
+ INFO :
+ INFO : [INIT]
+ INFO : Requesting initial parameters from one random client
+ INFO : Received initial parameters from one random client
+ INFO : Starting evaluation of initial global parameters
+ INFO : Evaluation returned no results (`None`)
+ INFO :
+ INFO : [ROUND 1]
+ INFO : configure_fit: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_fit: received 10 results and 0 failures
+ WARNING : No fit_metrics_aggregation_fn provided
+ INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_evaluate: received 10 results and 0 failures
+ WARNING : No evaluate_metrics_aggregation_fn provided
+ INFO :
+ INFO : [ROUND 2]
+ INFO : configure_fit: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_fit: received 10 results and 0 failures
+ INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_evaluate: received 10 results and 0 failures
+ INFO :
+ INFO : [ROUND 3]
+ INFO : configure_fit: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_fit: received 10 results and 0 failures
+ INFO : configure_evaluate: strategy sampled 10 clients (out of 10)
+ INFO : aggregate_evaluate: received 10 results and 0 failures
+ INFO :
+ INFO : [SUMMARY]
+ INFO : Run finished 3 round(s) in 6.07s
+ INFO : History (loss, distributed):
+ INFO : round 1: 0.29372873306274416
+ INFO : round 2: 5.820648354415425e-08
+ INFO : round 3: 1.526226667528834e-14
+ INFO :
+
+You can also override the parameters defined in the ``[tool.flwr.app.config]`` section
+in ``pyproject.toml`` like this:
+
+.. code-block:: shell
+
+ # Override some arguments
+ $ flwr run . --run-config "num-server-rounds=5 input-dim=5"
+
+What follows is an explanation of each component in the project you just created: the
+data, the model, the ``ClientApp``, and the ``ServerApp``.
+
+The Data
+--------
+
+This tutorial uses scikit-learn's |makeregression|_ function to generate a random
+regression problem.
.. code-block:: python
- def load_data() -> (
- Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]
- ):
- # create our dataset and start with similar datasets for different clients
+ def load_data():
+ # Load dataset
X, y = make_regression(n_features=3, random_state=0)
X, X_test, y, y_test = train_test_split(X, y)
return X, y, X_test, y_test
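+
+With the defaults used above, ``make_regression()`` generates 100 samples with 3
+features each, and ``train_test_split()`` keeps 75% of them for training. As a quick
+sanity check (a minimal sketch, not part of the generated project), you could inspect
+the shapes yourself:
+
+.. code-block:: python
+
+    X, y, X_test, y_test = load_data()
+    print(X.shape, y.shape)            # (75, 3) (75,)
+    print(X_test.shape, y_test.shape)  # (25, 3) (25,)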
-The model architecture (a very simple ``Linear Regression`` model) is defined in
-``load_model()``.
+The Model
+---------
+
+We define a simple linear regression model to demonstrate how to create a JAX model, but
+feel free to replace it with a more sophisticated one if you'd like (for example, a
+neural-network model built with `Flax <https://flax.readthedocs.io/en/latest/>`_):
.. code-block:: python
- def load_model(model_shape) -> Dict:
- # model weights
+ def load_model(model_shape):
+        # Initialize model parameters
params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
return params
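+
+For instance, with three input features the returned dictionary holds a scalar bias
+``b`` and a weight vector ``w`` with one entry per feature (a small sketch; it assumes
+``key`` is the module-level ``jax.random.PRNGKey`` used by ``load_model()``):
+
+.. code-block:: python
+
+    params = load_model((3,))
+    print(params["b"].shape)  # (), a scalar bias
+    print(params["w"].shape)  # (3,), one weight per input feature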
-We now need to define the training (function ``train()``), which loops over the training
-set and measures the loss (function ``loss_fn()``) for each batch of training examples.
-The loss function is separate since JAX takes derivatives with a ``grad()`` function
-(defined in the ``main()`` function and called in ``train()``).
+In addition to defining the model architecture, we also include two utility functions to
+perform both training (i.e. ``train()``) and evaluation (i.e. ``evaluation()``) using
+the above model.
.. code-block:: python
- def loss_fn(params, X, y) -> Callable:
+ def loss_fn(params, X, y):
+ # Return MSE as loss
err = jnp.dot(X, params["w"]) + params["b"] - y
- return jnp.mean(jnp.square(err)) # mse
+ return jnp.mean(jnp.square(err))
- def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
+ def train(params, grad_fn, X, y):
+        loss = 1_000_000  # Placeholder value, overwritten on the first epoch
num_examples = X.shape[0]
- for epochs in range(10):
+ for epochs in range(50):
grads = grad_fn(params, X, y)
- params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
+ params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
loss = loss_fn(params, X, y)
- # if epochs % 10 == 9:
- # print(f'For Epoch {epochs} loss {loss}')
return params, loss, num_examples
-The evaluation of the model is defined in the function ``evaluation()``. The function
-takes all test examples and measures the loss of the linear regression model.
-.. code-block:: python
-
- def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
+ def evaluation(params, grad_fn, X_test, y_test):
num_examples = X_test.shape[0]
err_test = loss_fn(params, X_test, y_test)
loss_test = jnp.mean(jnp.square(err_test))
- # print(f'Test loss {loss_test}')
return loss_test, num_examples
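+
+These utilities can also be used on their own, outside of Flower. A minimal centralized
+run (a sketch, run from the same module, ``task.py``, that defines the functions above)
+would look like this:
+
+.. code-block:: python
+
+    X, y, X_test, y_test = load_data()
+    grad_fn = jax.grad(loss_fn)  # JAX builds the gradient function of the loss
+    params = load_model(X.shape[1:])
+    params, train_loss, _ = train(params, grad_fn, X, y)
+    test_loss, _ = evaluation(params, grad_fn, X_test, y_test)
+    print(f"train loss: {train_loss}, test loss: {test_loss}")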
-Having defined the data loading, model architecture, training, and evaluation we can put
-everything together and train our model using JAX. As already mentioned, the
-``jax.grad()`` function is defined in ``main()`` and passed to ``train()``.
+The ClientApp
+-------------
+
+The main changes we have to make to use JAX with Flower will be found in the
+``get_params()`` and ``set_params()`` functions. In ``get_params()``, JAX model
+parameters are extracted and represented as a list of NumPy arrays. The ``set_params()``
+function is the opposite: given a list of NumPy arrays it applies them to an existing
+JAX model.
+
+.. note::
+
+ The ``get_params()`` and ``set_params()`` functions here are conceptually similar to
+ the ``get_weights()`` and ``set_weights()`` functions that we defined in the
+ :doc:`QuickStart PyTorch ` tutorial.
.. code-block:: python
- def main():
- X, y, X_test, y_test = load_data()
- model_shape = X.shape[1:]
- grad_fn = jax.grad(loss_fn)
- print("Model Shape", model_shape)
- params = load_model(model_shape)
- params, loss, num_examples = train(params, grad_fn, X, y)
- evaluation(params, grad_fn, X_test, y_test)
+ def get_params(params):
+ parameters = []
+ for _, val in params.items():
+ parameters.append(np.array(val))
+ return parameters
- if __name__ == "__main__":
- main()
+ def set_params(local_params, global_params):
+ for key, value in list(zip(local_params.keys(), global_params)):
+ local_params[key] = value
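+
+Together, these two helpers convert between the dictionary of JAX arrays used by the
+model and the list of NumPy arrays exchanged with the server. A quick round trip (a
+sketch for illustration only) shows the idea:
+
+.. code-block:: python
+
+    params = load_model((3,))
+    ndarrays = get_params(params)  # [bias as np.ndarray, weights as np.ndarray]
+    set_params(params, ndarrays)   # write the NumPy arrays back into the JAX model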
-You can now run your (centralized) JAX linear regression workload:
+The rest of the functionality builds directly on the ``train()`` and ``evaluation()``
+functions defined earlier. The ``fit()`` method in the client trains the model using the
+local dataset. Similarly, the ``evaluate()`` method is used to evaluate the received
+model on a held-out validation set that the client might have:
-.. code-block:: bash
+.. code-block:: python
- python3 jax_training.py
+ class FlowerClient(NumPyClient):
+ def __init__(self, input_dim):
+ self.train_x, self.train_y, self.test_x, self.test_y = load_data()
+ self.grad_fn = jax.grad(loss_fn)
+ model_shape = self.train_x.shape[1:]
-So far this should all look fairly familiar if you've used JAX before. Let's take the
-next step and use what we've built to create a simple federated learning system
-consisting of one server and two clients.
+ self.params = load_model(model_shape)
-JAX meets Flower
-----------------
+ def fit(self, parameters, config):
+ set_params(self.params, parameters)
+ self.params, loss, num_examples = train(
+ self.params, self.grad_fn, self.train_x, self.train_y
+ )
+            parameters = get_params(self.params)
+ return parameters, num_examples, {"loss": float(loss)}
-The concept of federating an existing workload is always the same and easy to
-understand. We have to start a *server* and then use the code in ``jax_training.py`` for
-the *clients* that are connected to the *server*. The *server* sends model parameters to
-the clients. The *clients* run the training and update the parameters. The updated
-parameters are sent back to the *server*, which averages all received parameter updates.
-This describes one round of the federated learning process, and we repeat this for
-multiple rounds.
+ def evaluate(self, parameters, config):
+ set_params(self.params, parameters)
+ loss, num_examples = evaluation(
+ self.params, self.grad_fn, self.test_x, self.test_y
+ )
+ return float(loss), num_examples, {"loss": float(loss)}
-Our example consists of one *server* and two *clients*. Let's set up ``server.py``
-first. The *server* needs to import the Flower package ``flwr``. Next, we use the
-``start_server`` function to start a server and tell it to perform three rounds of
-federated learning.
+Finally, we can construct a ``ClientApp`` using the ``FlowerClient`` defined above by
+means of a ``client_fn()`` callback. Note that the ``context`` enables you to get access
+to hyperparameters defined in your ``pyproject.toml`` to configure the run. In this
+tutorial we access the ``input-dim`` setting, which is passed to the ``FlowerClient``
+constructor. You could define additional hyperparameters in ``pyproject.toml`` and
+access them here.
.. code-block:: python
- import flwr as fl
-
- if __name__ == "__main__":
- fl.server.start_server(
- server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
- )
+ def client_fn(context: Context):
+ input_dim = context.run_config["input-dim"]
+ # Return Client instance
+ return FlowerClient(input_dim).to_client()
-We can already start the *server*:
-.. code-block:: bash
+ # Flower ClientApp
+ app = ClientApp(client_fn)
- python3 server.py
+The ServerApp
+-------------
-Finally, we will define our *client* logic in ``client.py`` and build upon the
-previously defined JAX training in ``jax_training.py``. Our *client* needs to import
-``flwr``, but also ``jax`` and ``jaxlib`` to update the parameters on our JAX model:
+To construct a ``ServerApp`` we define a ``server_fn()`` callback with an identical
+signature to that of ``client_fn()``, but the return type is |serverappcomponents|_ as
+opposed to a |client|_. In this example we use the ``FedAvg`` strategy. To it we pass a
+randomly initialized model that will serve as the global model to federate. Note that
+the value of ``input-dim`` is read from the run config. You can find the default value
+defined in the ``pyproject.toml``.
.. code-block:: python
- from typing import Dict, List, Callable, Tuple
-
- import flwr as fl
- import numpy as np
- import jax
- import jax.numpy as jnp
-
- import jax_training
-
-Implementing a Flower *client* basically means implementing a subclass of either
-``flwr.client.Client`` or ``flwr.client.NumPyClient``. Our implementation will be based
-on ``flwr.client.NumPyClient`` and we'll call it ``FlowerClient``. ``NumPyClient`` is
-slightly easier to implement than ``Client`` if you use a framework with good NumPy
-interoperability (like JAX) because it avoids some of the boilerplate that would
-otherwise be necessary. ``FlowerClient`` needs to implement four methods, two methods
-for getting/setting model parameters, one method for training the model, and one method
-for testing the model:
-
-1. ``set_parameters (optional)``
- - set the model parameters on the local model that are received from the server
- - transform parameters to NumPy ``ndarray``'s
- - loop over the list of model parameters received as NumPy ``ndarray``'s (think
- list of neural network layers)
-2. ``get_parameters``
- - get the model parameters and return them as a list of NumPy ``ndarray``'s
- (which is what ``flwr.client.NumPyClient`` expects)
-3. ``fit``
- - update the parameters of the local model with the parameters received from the
- server
- - train the model on the local training set
- - get the updated local model parameters and return them to the server
-4. ``evaluate``
- - update the parameters of the local model with the parameters received from the
- server
- - evaluate the updated model on the local test set
- - return the local loss to the server
-
-The challenging part is to transform the JAX model parameters from ``DeviceArray`` to
-``NumPy ndarray`` to make them compatible with `NumPyClient`.
-
-The two ``NumPyClient`` methods ``fit`` and ``evaluate`` make use of the functions
-``train()`` and ``evaluate()`` previously defined in ``jax_training.py``. So what we
-really do here is we tell Flower through our ``NumPyClient`` subclass which of our
-already defined functions to call for training and evaluation. We included type
-annotations to give you a better understanding of the data types that get passed around.
+ def server_fn(context: Context):
+ # Read from config
+ num_rounds = context.run_config["num-server-rounds"]
+ input_dim = context.run_config["input-dim"]
-.. code-block:: python
+ # Initialize global model
+ params = get_params(load_model((input_dim,)))
+ initial_parameters = ndarrays_to_parameters(params)
- class FlowerClient(fl.client.NumPyClient):
- """Flower client implementing using linear regression and JAX."""
-
- def __init__(
- self,
- params: Dict,
- grad_fn: Callable,
- train_x: List[np.ndarray],
- train_y: List[np.ndarray],
- test_x: List[np.ndarray],
- test_y: List[np.ndarray],
- ) -> None:
- self.params = params
- self.grad_fn = grad_fn
- self.train_x = train_x
- self.train_y = train_y
- self.test_x = test_x
- self.test_y = test_y
-
- def get_parameters(self, config) -> Dict:
- # Return model parameters as a list of NumPy ndarrays
- parameter_value = []
- for _, val in self.params.items():
- parameter_value.append(np.array(val))
- return parameter_value
-
- def set_parameters(self, parameters: List[np.ndarray]) -> Dict:
- # Collect model parameters and update the parameters of the local model
- value = jnp.ndarray
- params_item = list(zip(self.params.keys(), parameters))
- for item in params_item:
- key = item[0]
- value = item[1]
- self.params[key] = value
- return self.params
-
- def fit(
- self, parameters: List[np.ndarray], config: Dict
- ) -> Tuple[List[np.ndarray], int, Dict]:
- # Set model parameters, train model, return updated model parameters
- print("Start local training")
- self.params = self.set_parameters(parameters)
- self.params, loss, num_examples = jax_training.train(
- self.params, self.grad_fn, self.train_x, self.train_y
- )
- results = {"loss": float(loss)}
- print("Training results", results)
- return self.get_parameters(config={}), num_examples, results
-
- def evaluate(
- self, parameters: List[np.ndarray], config: Dict
- ) -> Tuple[float, int, Dict]:
- # Set model parameters, evaluate the model on a local test dataset, return result
- print("Start evaluation")
- self.params = self.set_parameters(parameters)
- loss, num_examples = jax_training.evaluation(
- self.params, self.grad_fn, self.test_x, self.test_y
- )
- print("Evaluation accuracy & loss", loss)
- return (
- float(loss),
- num_examples,
- {"loss": float(loss)},
- )
+ # Define strategy
+ strategy = FedAvg(initial_parameters=initial_parameters)
+ config = ServerConfig(num_rounds=num_rounds)
-Having defined the federation process, we can run it.
+ return ServerAppComponents(strategy=strategy, config=config)
-.. code-block:: python
- def main() -> None:
- """Load data, start MNISTClient."""
+ # Create ServerApp
+ app = ServerApp(server_fn=server_fn)
+
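+``FedAvg`` accepts additional arguments if you want to change how nodes are sampled in
+each round. As a sketch (not part of the generated template), you could configure the
+strategy like this:
+
+.. code-block:: python
+
+    strategy = FedAvg(
+        fraction_fit=0.5,           # sample 50% of connected nodes for training
+        min_available_clients=10,   # wait until at least 10 nodes are connected
+        initial_parameters=initial_parameters,
+    )
+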
+Congratulations! You've successfully built and run your first federated learning system
+for JAX with Flower!
+
+.. note::
- # Load data
- train_x, train_y, test_x, test_y = jax_training.load_data()
- grad_fn = jax.grad(jax_training.loss_fn)
+ Check the source code of the extended version of this tutorial in
+ |quickstart_jax_link|_ in the Flower GitHub repository.
- # Load model (from centralized training) and initialize parameters
- model_shape = train_x.shape[1:]
- params = jax_training.load_model(model_shape)
+.. |client| replace:: ``Client``
- # Start Flower client
- client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y)
- fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
+.. |fedavg| replace:: ``FedAvg``
+.. |makeregression| replace:: ``make_regression()``
- if __name__ == "__main__":
- main()
+.. |quickstart_jax_link| replace:: ``examples/quickstart-jax``
-And that's it. You can now open two additional terminal windows and run
+.. |serverappcomponents| replace:: ``ServerAppComponents``
-.. code-block:: bash
+.. _client: ref-api/flwr.client.Client.html#client
- python3 client.py
+.. _fedavg: ref-api/flwr.server.strategy.FedAvg.html#flwr.server.strategy.FedAvg
-in each window (make sure that the server is still running before you do so) and see
-your JAX project run federated learning across two clients. Congratulations!
+.. _makeregression: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html
-Next Steps
-----------
+.. _quickstart_jax_link: https://github.com/adap/flower/tree/main/examples/quickstart-jax
-The source code of this example was improved over time and can be found here:
-`Quickstart JAX `_.
-Our example is somewhat over-simplified because both clients load the same dataset.
+.. _serverappcomponents: ref-api/flwr.server.ServerAppComponents.html#serverappcomponents
-You're now prepared to explore this topic further. How about using a more sophisticated
-model or using a different dataset? How about adding more clients?
+.. meta::
+    :description: Check out this Federated Learning quickstart tutorial for using Flower with JAX to train a linear regression model on a scikit-learn dataset.