Federated Learning: A Simple Implementation of FedAvg (Federated Averaging) with PyTorch



What would the performance of a centralized model trained on all of the training data be?

First, let’s examine what the performance of the centralized model would be if the data were not distributed to the nodes at all.

 — — — Centralized Model — — — 
epoch: 1 | train accuracy: 0.8743 | test accuracy: 0.9437
epoch: 2 | train accuracy: 0.9567 | test accuracy: 0.9654
epoch: 3 | train accuracy: 0.9712 | test accuracy: 0.9701
epoch: 4 | train accuracy: 0.9785 | test accuracy: 0.9738
epoch: 5 | train accuracy: 0.9834 | test accuracy: 0.9713
epoch: 6 | train accuracy: 0.9864 | test accuracy: 0.9768
epoch: 7 | train accuracy: 0.9898 | test accuracy: 0.9763
epoch: 8 | train accuracy: 0.9923 | test accuracy: 0.9804
epoch: 9 | train accuracy: 0.9941 | test accuracy: 0.9784
epoch: 10 | train accuracy: 0.9959 | test accuracy: 0.9792
— — — Training finished — — —

The model used in this example is very simple; various improvements could be made to increase its performance, such as using a more complex model, training for more epochs, or tuning hyperparameters. However, the purpose here is to compare the main model, formed by combining the parameters of the local models trained on their own data, with a centralized model trained on all of the training data. In this way, we can gain insight into the capability of federated learning.
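The exact architecture is in the notebook linked at the end; as a rough illustration of what such a simple model might look like, here is a minimal two-hidden-layer network for 28x28 MNIST images (the class name SimpleNet and the layer sizes are illustrative, not necessarily the ones used in the article):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    """A small fully connected network for 28x28 MNIST-style images."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 200)
        self.out = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view(-1, 784)           # flatten the image into a 784-dim vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)            # raw logits; CrossEntropyLoss applies softmax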

Now, let’s start our first iteration.

Data is distributed to nodes
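A minimal sketch of this IID split, assuming 10 nodes of equal size and standard torchvision MNIST loading (the names num_nodes and node_loaders are placeholders for illustration):

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

num_nodes = 10
transform = transforms.ToTensor()
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)

# Shuffle the training set and cut it into equally sized shards, one per node.
shard_size = len(train_set) // num_nodes
lengths = [shard_size] * num_nodes
lengths[-1] += len(train_set) - sum(lengths)   # put any remainder in the last shard
node_datasets = random_split(train_set, lengths,
                             generator=torch.Generator().manual_seed(0))

node_loaders = {
    f"node_{i}": DataLoader(ds, batch_size=32, shuffle=True)
    for i, ds in enumerate(node_datasets)
}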

The main model is created
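For example, reusing the SimpleNet class sketched above:

# The main (server-side) model. Its randomly initialized parameters will be
# broadcast to the nodes, and after each round they will be overwritten with
# the average of the locally trained parameters, so it needs no optimizer here.
main_model = SimpleNet()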

Models, optimizers, and loss functions in nodes are defined
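One way to keep the per-node objects organized is a set of dicts keyed by node index; the dict and key names below are illustrative, as are the optimizer settings:

model_dict, optimizer_dict, criterion_dict = {}, {}, {}

for i in range(num_nodes):
    model_dict[f"model_{i}"] = SimpleNet()
    optimizer_dict[f"optimizer_{i}"] = torch.optim.SGD(
        model_dict[f"model_{i}"].parameters(), lr=0.01, momentum=0.9
    )
    criterion_dict[f"criterion_{i}"] = nn.CrossEntropyLoss()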

The keys of the dicts are made iterable
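A trivial sketch, with the same hypothetical names as above; converting the keys to plain lists makes it easy to address the i-th node in later loops:

model_names = list(model_dict.keys())
optimizer_names = list(optimizer_dict.keys())
criterion_names = list(criterion_dict.keys())
loader_names = list(node_loaders.keys())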

Parameters of the main model are sent to nodes
Since the parameters of the main model and the parameters of all the local models in the nodes are randomly initialized, they will all differ from each other. For this reason, the main model sends its parameters to the nodes before the training of the local models begins. You can check the weights below.
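A minimal sketch of this broadcast step, assuming the main_model and model_dict defined above; the helper name send_main_model_to_nodes is illustrative:

def send_main_model_to_nodes(main_model, model_dict):
    """Copy the current main-model weights into every local model."""
    main_state = main_model.state_dict()
    for local_model in model_dict.values():
        local_model.load_state_dict(main_state)
    return model_dict

model_dict = send_main_model_to_nodes(main_model, model_dict)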

Models in the nodes are trained
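A sketch of the local training step, reusing the hypothetical dicts and loaders from above; each node runs a small number of local epochs on its own shard only:

def train_local_models(model_dict, optimizer_dict, criterion_dict,
                       node_loaders, local_epochs=1):
    """Train every node's model on that node's own data shard."""
    for i in range(len(model_dict)):
        model = model_dict[f"model_{i}"]
        optimizer = optimizer_dict[f"optimizer_{i}"]
        criterion = criterion_dict[f"criterion_{i}"]
        loader = node_loaders[f"node_{i}"]

        model.train()
        for _ in range(local_epochs):
            for images, labels in loader:
                optimizer.zero_grad()
                loss = criterion(model(images), labels)
                loss.backward()
                optimizer.step()

train_local_models(model_dict, optimizer_dict, criterion_dict, node_loaders)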

Let’s compare the performance of the federated main model and centralized model
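To form the federated main model, the parameters of the locally trained models are averaged (the FedAvg step, with equal weights here since every node holds the same amount of data in this IID setting) and loaded into the main model, which can then be evaluated on the full test set. A minimal sketch, reusing the names introduced above:

def average_models_into_main(main_model, model_dict):
    """FedAvg step: set each main-model parameter to the mean of the
    corresponding parameters across all local models (equal node weights)."""
    local_states = [m.state_dict() for m in model_dict.values()]
    avg_state = {
        key: torch.stack([state[key].float() for state in local_states]).mean(dim=0)
        for key in local_states[0]
    }
    main_model.load_state_dict(avg_state)
    return main_model

def accuracy(model, loader):
    """Fraction of correctly classified examples over a data loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

test_set = datasets.MNIST("./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=256)

main_model = average_models_into_main(main_model, model_dict)
print("main model accuracy on all test data:", accuracy(main_model, test_loader))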

Federated main model vs. centralized model before and after the 1st iteration (on all test data)
Since the main model is randomly initialized and no training has been performed on it yet, its performance before the first iteration is very poor. After the first iteration, the accuracy of the main model increased to 85%.

Before 1st iteration main model accuracy on all test data: 0.1180
After 1st iteration main model accuracy on all test data: 0.8529
Centralized model accuracy on all test data: 0.9790

This was a single iteration; we can send the parameters of the main model back to the nodes and repeat the steps above. Now let’s check how the performance of the main model improves when we repeat the iteration 10 more times.
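Putting the pieces together, one possible shape for the remaining rounds, reusing the hypothetical helpers sketched above:

for iteration in range(2, 12):
    # 1. broadcast the current main-model parameters to the nodes
    model_dict = send_main_model_to_nodes(main_model, model_dict)
    # 2. train each local model on its own shard
    train_local_models(model_dict, optimizer_dict, criterion_dict, node_loaders)
    # 3. average the local parameters back into the main model
    main_model = average_models_into_main(main_model, model_dict)
    # 4. evaluate the main model on the full test set
    acc = accuracy(main_model, test_loader)
    print(f"Iteration {iteration} : main_model accuracy on all test data: {acc:.4f}")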

Iteration 2 : main_model accuracy on all test data: 0.8928
Iteration 3 : main_model accuracy on all test data: 0.9073
Iteration 4 : main_model accuracy on all test data: 0.9150
Iteration 5 : main_model accuracy on all test data: 0.9209
Iteration 6 : main_model accuracy on all test data: 0.9273
Iteration 7 : main_model accuracy on all test data: 0.9321
Iteration 8 : main_model accuracy on all test data: 0.9358
Iteration 9 : main_model accuracy on all test data: 0.9382
Iteration 10 : main_model accuracy on all test data: 0.9411
Iteration 11 : main_model accuracy on all test data: 0.9431

The accuracy of the centralized model was approximately 98%. The accuracy of the main model obtained with the FedAvg method started at 85% and improved to 94%. In other words, although the main model obtained with FedAvg was built without ever seeing the training data directly, its performance should not be underestimated.

You can visit https://github.com/eceisik/fl_public/blob/master/fedavg_mnist_iid.ipynb to see the full implementation.

[1] J. Konečný, H. B. McMahan, D. Ramage, and P. Richtárik, “Federated Optimization: Distributed Machine Learning for On-Device Intelligence,” pp. 1–38, 2016.

[2] H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas, “Communication-Efficient Learning of Deep Networks from Decentralized Data,” in Proceedings of AISTATS, vol. 54, 2017.

[3] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278–2324, November 1998.