From 2a7547dd5e6bdfb6f2ca94ccd34ba513eba56c52 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 16 Oct 2024 19:44:04 +0200 Subject: [PATCH] docs(framework) Update Quickstart Tutorial documentation for JAX with `flwr run` (#3367) Co-authored-by: jafermarq --- doc/source/tutorial-quickstart-jax.rst | 459 ++++++++++++------------- 1 file changed, 219 insertions(+), 240 deletions(-) diff --git a/doc/source/tutorial-quickstart-jax.rst b/doc/source/tutorial-quickstart-jax.rst index 0581e95d8d4..833270d5636 100644 --- a/doc/source/tutorial-quickstart-jax.rst +++ b/doc/source/tutorial-quickstart-jax.rst @@ -3,324 +3,303 @@ Quickstart JAX ============== -.. meta:: - :description: Check out this Federated Learning quickstart tutorial for using Flower with Jax to train a linear regression model on a scikit-learn dataset. +In this federated learning tutorial we will learn how to train a linear regression model +using Flower and `JAX `_. It is recommended to +create a virtual environment and run everything within a :doc:`virtualenv +`. -This tutorial will show you how to use Flower to build a federated version of an -existing JAX workload. We are using JAX to train a linear regression model on a -scikit-learn dataset. We will structure the example similar to our `PyTorch - From -Centralized To Federated -`_ -walkthrough. First, we build a centralized training approach based on the `Linear -Regression with JAX -`_ tutorial`. -Then, we build upon the centralized training code to run the training in a federated -fashion. +Let's use ``flwr new`` to create a complete Flower+JAX project. It will generate all the +files needed to run, by default with the Flower Simulation Engine, a federation of 10 +nodes using |fedavg|_. A random regression dataset will be loaded from scikit-learn's +|makeregression|_ function. -Before we start building our JAX example, we need install the packages ``jax``, -``jaxlib``, ``scikit-learn``, and ``flwr``: +Now that we have a rough idea of what this example is about, let's get started. First, +install Flower in your new environment: .. code-block:: shell - $ pip install jax jaxlib scikit-learn flwr + # In a new Python environment + $ pip install flwr + +Then, run the command below. You will be prompted to select one of the available +templates (choose ``JAX``), give a name to your project, and type in your developer +name: -Linear Regression with JAX --------------------------- +.. code-block:: shell -We begin with a brief description of the centralized training code based on a ``Linear -Regression`` model. If you want a more in-depth explanation of what's going on then have -a look at the official `JAX documentation `_. + $ flwr new -Let's create a new file called ``jax_training.py`` with all the components required for -a traditional (centralized) linear regression training. First, the JAX packages ``jax`` -and ``jaxlib`` need to be imported. In addition, we need to import ``sklearn`` since we -use ``make_regression`` for the dataset and ``train_test_split`` to split the dataset -into a training and test set. You can see that we do not yet import the ``flwr`` package -for federated learning. This will be done later. +After running it you'll notice a new directory with your project name has been created. +It should have the following structure: -.. code-block:: python +.. 
code-block:: shell - from typing import Dict, List, Tuple, Callable - import jax - import jax.numpy as jnp - from sklearn.datasets import make_regression - from sklearn.model_selection import train_test_split + + ├── + │ ├── __init__.py + │ ├── client_app.py # Defines your ClientApp + │ ├── server_app.py # Defines your ServerApp + │ └── task.py # Defines your model, training and data loading + ├── pyproject.toml # Project metadata like dependencies and configs + └── README.md - key = jax.random.PRNGKey(0) +If you haven't yet installed the project and its dependencies, you can do so by: + +.. code-block:: shell + + # From the directory where your pyproject.toml is + $ pip install -e . + +To run the project, do: + +.. code-block:: shell -The ``load_data()`` function loads the mentioned training and test sets. + # Run with default arguments + $ flwr run . + +With default arguments you will see an output like this one: + +.. code-block:: shell + + Loading project configuration... + Success + INFO : Starting Flower ServerApp, config: num_rounds=3, no round_timeout + INFO : + INFO : [INIT] + INFO : Requesting initial parameters from one random client + INFO : Received initial parameters from one random client + INFO : Starting evaluation of initial global parameters + INFO : Evaluation returned no results (`None`) + INFO : + INFO : [ROUND 1] + INFO : configure_fit: strategy sampled 10 clients (out of 10) + INFO : aggregate_fit: received 10 results and 0 failures + WARNING : No fit_metrics_aggregation_fn provided + INFO : configure_evaluate: strategy sampled 10 clients (out of 10) + INFO : aggregate_evaluate: received 10 results and 0 failures + WARNING : No evaluate_metrics_aggregation_fn provided + INFO : + INFO : [ROUND 2] + INFO : configure_fit: strategy sampled 10 clients (out of 10) + INFO : aggregate_fit: received 10 results and 0 failures + INFO : configure_evaluate: strategy sampled 10 clients (out of 10) + INFO : aggregate_evaluate: received 10 results and 0 failures + INFO : + INFO : [ROUND 3] + INFO : configure_fit: strategy sampled 10 clients (out of 10) + INFO : aggregate_fit: received 10 results and 0 failures + INFO : configure_evaluate: strategy sampled 10 clients (out of 10) + INFO : aggregate_evaluate: received 10 results and 0 failures + INFO : + INFO : [SUMMARY] + INFO : Run finished 3 round(s) in 6.07s + INFO : History (loss, distributed): + INFO : round 1: 0.29372873306274416 + INFO : round 2: 5.820648354415425e-08 + INFO : round 3: 1.526226667528834e-14 + INFO : + +You can also override the parameters defined in the ``[tool.flwr.app.config]`` section +in ``pyproject.toml`` like this: + +.. code-block:: shell + + # Override some arguments + $ flwr run . --run-config "num-server-rounds=5 input-dim=5" + +What follows is an explanation of each component in the project you just created: +dataset partition, the model, defining the ``ClientApp`` and defining the ``ServerApp``. + +The Data +-------- + +This tutorial uses scikit-learn's |makeregression|_ function to generate a random +regression problem. .. code-block:: python - def load_data() -> ( - Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]] - ): - # create our dataset and start with similar datasets for different clients + def load_data(): + # Load dataset X, y = make_regression(n_features=3, random_state=0) X, X_test, y, y_test = train_test_split(X, y) return X, y, X_test, y_test -The model architecture (a very simple ``Linear Regression`` model) is defined in -``load_model()``. 
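+Note that, as generated, every ``ClientApp`` calls ``load_data()`` with the same
+``random_state``, so all clients in this quickstart train on an identical dataset. If
+you would like each client to hold its own partition, one option is to derive a
+per-client seed. The sketch below is only an illustration of this idea; the ``seed``
+parameter is an assumption and is not part of the generated project:
+
+.. code-block:: python
+
+    def load_data(seed: int = 0):
+        # Generate a client-specific regression dataset
+        X, y = make_regression(n_features=3, random_state=seed)
+        X, X_test, y, y_test = train_test_split(X, y, random_state=seed)
+        return X, y, X_test, y_test
+
+Inside ``client_fn()`` (shown later in this tutorial) you could then read the
+``partition-id`` that the Simulation Engine exposes via ``context.node_config`` and
+pass it as the seed.
+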
+The Model
+---------
+
+We define a simple linear regression model to demonstrate how to create a JAX model,
+but feel free to replace it with a more sophisticated JAX model if you'd like (for
+example, an NN-based model written with `Flax <https://flax.readthedocs.io/>`_):

 .. code-block:: python

-    def load_model(model_shape) -> Dict:
-        # model weights
+    def load_model(model_shape):
+        # Initialize model parameters with random values
         params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
         return params

-We now need to define the training (function ``train()``), which loops over the training
-set and measures the loss (function ``loss_fn()``) for each batch of training examples.
-The loss function is separate since JAX takes derivatives with a ``grad()`` function
-(defined in the ``main()`` function and called in ``train()``).
+In addition to defining the model architecture, we also include two utility functions
+that perform training (``train()``) and evaluation (``evaluation()``) with the above
+model.

 .. code-block:: python

-    def loss_fn(params, X, y) -> Callable:
+    def loss_fn(params, X, y):
+        # Return MSE as loss
         err = jnp.dot(X, params["w"]) + params["b"] - y
-        return jnp.mean(jnp.square(err))  # mse
+        return jnp.mean(jnp.square(err))


-    def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]:
+    def train(params, grad_fn, X, y):
+        loss = 1_000_000
         num_examples = X.shape[0]
-        for epochs in range(10):
+        for epochs in range(50):
             grads = grad_fn(params, X, y)
-            params = jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
+            params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
             loss = loss_fn(params, X, y)
-            # if epochs % 10 == 9:
-            #     print(f'For Epoch {epochs} loss {loss}')
         return params, loss, num_examples

-The evaluation of the model is defined in the function ``evaluation()``. The function
-takes all test examples and measures the loss of the linear regression model.

-.. code-block:: python
-
-    def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]:
+    def evaluation(params, grad_fn, X_test, y_test):
         num_examples = X_test.shape[0]
-        err_test = loss_fn(params, X_test, y_test)
-        loss_test = jnp.mean(jnp.square(err_test))
-        # print(f'Test loss {loss_test}')
+        # loss_fn() already returns the MSE over the test set
+        loss_test = loss_fn(params, X_test, y_test)
         return loss_test, num_examples

-Having defined the data loading, model architecture, training, and evaluation we can put
-everything together and train our model using JAX. As already mentioned, the
-``jax.grad()`` function is defined in ``main()`` and passed to ``train()``.
+The ClientApp
+-------------
+
+The main changes we have to make to use JAX with Flower will be found in the
+``get_params()`` and ``set_params()`` functions. In ``get_params()``, JAX model
+parameters are extracted and represented as a list of NumPy arrays. The ``set_params()``
+function is the opposite: given a list of NumPy arrays, it applies them to an existing
+JAX model.
+
+.. note::
+
+    The ``get_params()`` and ``set_params()`` functions here are conceptually similar to
+    the ``get_weights()`` and ``set_weights()`` functions that we defined in the
+    :doc:`QuickStart PyTorch <tutorial-quickstart-pytorch>` tutorial.

.. 
code-block:: python

-    def main():
-        X, y, X_test, y_test = load_data()
-        model_shape = X.shape[1:]
-        grad_fn = jax.grad(loss_fn)
-        print("Model Shape", model_shape)
-        params = load_model(model_shape)
-        params, loss, num_examples = train(params, grad_fn, X, y)
-        evaluation(params, grad_fn, X_test, y_test)
+    def get_params(params):
+        parameters = []
+        for _, val in params.items():
+            parameters.append(np.array(val))
+        return parameters

-    if __name__ == "__main__":
-        main()
+    def set_params(local_params, global_params):
+        for key, value in list(zip(local_params.keys(), global_params)):
+            local_params[key] = value

-You can now run your (centralized) JAX linear regression workload:
+The rest of the functionality is directly inspired by the centralized case. The
+``fit()`` method in the client trains the model using the local dataset. Similarly, the
+``evaluate()`` method is used to evaluate the received model on a held-out validation
+set that the client might have:

-.. code-block:: bash
+.. code-block:: python

-    python3 jax_training.py
+    class FlowerClient(NumPyClient):
+        def __init__(self, input_dim):
+            self.train_x, self.train_y, self.test_x, self.test_y = load_data()
+            self.grad_fn = jax.grad(loss_fn)
+            model_shape = self.train_x.shape[1:]

-So far this should all look fairly familiar if you've used JAX before. Let's take the
-next step and use what we've built to create a simple federated learning system
-consisting of one server and two clients.
+            self.params = load_model(model_shape)

-JAX meets Flower
-----------------
+        def fit(self, parameters, config):
+            set_params(self.params, parameters)
+            self.params, loss, num_examples = train(
+                self.params, self.grad_fn, self.train_x, self.train_y
+            )
+            # Return the locally updated model parameters
+            parameters = get_params(self.params)
+            return parameters, num_examples, {"loss": float(loss)}

-The concept of federating an existing workload is always the same and easy to
-understand. We have to start a *server* and then use the code in ``jax_training.py`` for
-the *clients* that are connected to the *server*. The *server* sends model parameters to
-the clients. The *clients* run the training and update the parameters. The updated
-parameters are sent back to the *server*, which averages all received parameter updates.
-This describes one round of the federated learning process, and we repeat this for
-multiple rounds.
+        def evaluate(self, parameters, config):
+            set_params(self.params, parameters)
+            loss, num_examples = evaluation(
+                self.params, self.grad_fn, self.test_x, self.test_y
+            )
+            return float(loss), num_examples, {"loss": float(loss)}

-Our example consists of one *server* and two *clients*. Let's set up ``server.py``
-first. The *server* needs to import the Flower package ``flwr``. Next, we use the
-``start_server`` function to start a server and tell it to perform three rounds of
-federated learning.
+Finally, we can construct a ``ClientApp`` using the ``FlowerClient`` defined above by
+means of a ``client_fn()`` callback. Note that ``context`` enables you to get access to
+hyperparameters defined in your ``pyproject.toml`` to configure the run. In this
+tutorial we access the ``input-dim`` setting and pass it to the ``FlowerClient``
+constructor. You could define additional hyperparameters in ``pyproject.toml`` and
+access them here.

.. 
code-block:: python

-    import flwr as fl
-
-    if __name__ == "__main__":
-        fl.server.start_server(
-            server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3)
-        )
+    def client_fn(context: Context):
+        input_dim = context.run_config["input-dim"]
+        # Return Client instance
+        return FlowerClient(input_dim).to_client()

-We can already start the *server*:

-.. code-block:: bash
+    # Flower ClientApp
+    app = ClientApp(client_fn)

-    python3 server.py
+The ServerApp
+-------------

-Finally, we will define our *client* logic in ``client.py`` and build upon the
-previously defined JAX training in ``jax_training.py``. Our *client* needs to import
-``flwr``, but also ``jax`` and ``jaxlib`` to update the parameters on our JAX model:
+To construct a ``ServerApp`` we define a ``server_fn()`` callback with an identical
+signature to that of ``client_fn()``, but the return type is |serverappcomponents|_ as
+opposed to a |client|_. In this example we use the ``FedAvg`` strategy. We pass it a
+randomly initialized model that will serve as the global model to be federated. Note
+that the value of ``input_dim`` is read from the run config. You can find the default
+value defined in the ``pyproject.toml``.

 .. code-block:: python

-    from typing import Dict, List, Callable, Tuple
-
-    import flwr as fl
-    import numpy as np
-    import jax
-    import jax.numpy as jnp
-
-    import jax_training
-
-Implementing a Flower *client* basically means implementing a subclass of either
-``flwr.client.Client`` or ``flwr.client.NumPyClient``. Our implementation will be based
-on ``flwr.client.NumPyClient`` and we'll call it ``FlowerClient``. ``NumPyClient`` is
-slightly easier to implement than ``Client`` if you use a framework with good NumPy
-interoperability (like JAX) because it avoids some of the boilerplate that would
-otherwise be necessary. ``FlowerClient`` needs to implement four methods, two methods
-for getting/setting model parameters, one method for training the model, and one method
-for testing the model:
-
-1. ``set_parameters (optional)``
-       - set the model parameters on the local model that are received from the server
-       - transform parameters to NumPy ``ndarray``'s
-       - loop over the list of model parameters received as NumPy ``ndarray``'s (think
-         list of neural network layers)
-2. ``get_parameters``
-       - get the model parameters and return them as a list of NumPy ``ndarray``'s
-         (which is what ``flwr.client.NumPyClient`` expects)
-3. ``fit``
-       - update the parameters of the local model with the parameters received from the
-         server
-       - train the model on the local training set
-       - get the updated local model parameters and return them to the server
-4. ``evaluate``
-       - update the parameters of the local model with the parameters received from the
-         server
-       - evaluate the updated model on the local test set
-       - return the local loss to the server
-
-The challenging part is to transform the JAX model parameters from ``DeviceArray`` to
-``NumPy ndarray`` to make them compatible with `NumPyClient`.
-
-The two ``NumPyClient`` methods ``fit`` and ``evaluate`` make use of the functions
-``train()`` and ``evaluate()`` previously defined in ``jax_training.py``. So what we
-really do here is we tell Flower through our ``NumPyClient`` subclass which of our
-already defined functions to call for training and evaluation. We included type
-annotations to give you a better understanding of the data types that get passed around. 
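+    # Imports assumed by this snippet (typically at the top of server_app.py in the
+    # generated project). The task-module import is shown as a comment because its
+    # package name depends on the project name you chose with `flwr new`.
+    from flwr.common import Context, ndarrays_to_parameters
+    from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+    from flwr.server.strategy import FedAvg
+
+    # from <your-project-name>.task import get_params, load_model
+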
+ def server_fn(context: Context): + # Read from config + num_rounds = context.run_config["num-server-rounds"] + input_dim = context.run_config["input-dim"] -.. code-block:: python + # Initialize global model + params = get_params(load_model((input_dim,))) + initial_parameters = ndarrays_to_parameters(params) - class FlowerClient(fl.client.NumPyClient): - """Flower client implementing using linear regression and JAX.""" - - def __init__( - self, - params: Dict, - grad_fn: Callable, - train_x: List[np.ndarray], - train_y: List[np.ndarray], - test_x: List[np.ndarray], - test_y: List[np.ndarray], - ) -> None: - self.params = params - self.grad_fn = grad_fn - self.train_x = train_x - self.train_y = train_y - self.test_x = test_x - self.test_y = test_y - - def get_parameters(self, config) -> Dict: - # Return model parameters as a list of NumPy ndarrays - parameter_value = [] - for _, val in self.params.items(): - parameter_value.append(np.array(val)) - return parameter_value - - def set_parameters(self, parameters: List[np.ndarray]) -> Dict: - # Collect model parameters and update the parameters of the local model - value = jnp.ndarray - params_item = list(zip(self.params.keys(), parameters)) - for item in params_item: - key = item[0] - value = item[1] - self.params[key] = value - return self.params - - def fit( - self, parameters: List[np.ndarray], config: Dict - ) -> Tuple[List[np.ndarray], int, Dict]: - # Set model parameters, train model, return updated model parameters - print("Start local training") - self.params = self.set_parameters(parameters) - self.params, loss, num_examples = jax_training.train( - self.params, self.grad_fn, self.train_x, self.train_y - ) - results = {"loss": float(loss)} - print("Training results", results) - return self.get_parameters(config={}), num_examples, results - - def evaluate( - self, parameters: List[np.ndarray], config: Dict - ) -> Tuple[float, int, Dict]: - # Set model parameters, evaluate the model on a local test dataset, return result - print("Start evaluation") - self.params = self.set_parameters(parameters) - loss, num_examples = jax_training.evaluation( - self.params, self.grad_fn, self.test_x, self.test_y - ) - print("Evaluation accuracy & loss", loss) - return ( - float(loss), - num_examples, - {"loss": float(loss)}, - ) + # Define strategy + strategy = FedAvg(initial_parameters=initial_parameters) + config = ServerConfig(num_rounds=num_rounds) -Having defined the federation process, we can run it. + return ServerAppComponents(strategy=strategy, config=config) -.. code-block:: python - def main() -> None: - """Load data, start MNISTClient.""" + # Create ServerApp + app = ServerApp(server_fn=server_fn) + +Congratulations! You've successfully built and run your first federated learning system +for JAX with Flower! + +.. note:: - # Load data - train_x, train_y, test_x, test_y = jax_training.load_data() - grad_fn = jax.grad(jax_training.loss_fn) + Check the source code of the extended version of this tutorial in + |quickstart_jax_link|_ in the Flower GitHub repository. - # Load model (from centralized training) and initialize parameters - model_shape = train_x.shape[1:] - params = jax_training.load_model(model_shape) +.. |client| replace:: ``Client`` - # Start Flower client - client = FlowerClient(params, grad_fn, train_x, train_y, test_x, test_y) - fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) +.. |fedavg| replace:: ``FedAvg`` +.. 
|makeregression| replace:: ``make_regression()`` - if __name__ == "__main__": - main() +.. |quickstart_jax_link| replace:: ``examples/quickstart-jax`` -And that's it. You can now open two additional terminal windows and run +.. |serverappcomponents| replace:: ``ServerAppComponents`` -.. code-block:: bash +.. _client: ref-api/flwr.client.Client.html#client - python3 client.py +.. _fedavg: ref-api/flwr.server.strategy.FedAvg.html#flwr.server.strategy.FedAvg -in each window (make sure that the server is still running before you do so) and see -your JAX project run federated learning across two clients. Congratulations! +.. _makeregression: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html -Next Steps ----------- +.. _quickstart_jax_link: https://github.com/adap/flower/tree/main/examples/quickstart-jax -The source code of this example was improved over time and can be found here: -`Quickstart JAX `_. -Our example is somewhat over-simplified because both clients load the same dataset. +.. _serverappcomponents: ref-api/flwr.server.ServerAppComponents.html#serverappcomponents -You're now prepared to explore this topic further. How about using a more sophisticated -model or using a different dataset? How about adding more clients? +.. meta:: + :description: Check out this Federated Learning quickstart tutorial for using Flower with Jax to train a linear regression model on a scikit-learn dataset.