Original article was published on Deep Learning on Medium
In this series, CIFAR-10 is used as the benchmark dataset and is further converted into a non-IID dataset. To learn more about the basics of federated learning, please head over to Part 1 of this series. In this tutorial, we will create two different types of datasets: one replicating real-life data, i.e. a real-world dataset, and another an extreme example of a non-IID dataset.
REAL-WORLD DATASET: CIFAR-10 is randomly divided among the given number of clients, so a client can have images from any number of classes; say, one client has images from only one class while another has images from five classes. This type of dataset replicates the real-world scenario, where clients can hold different kinds of images.
NON-IID DATASET: The real-world dataset also falls into this category. For this case, however, a different type of distribution is used: the CIFAR-10 dataset is divided based on the classes-per-client parameter. E.g., if classes per client = 2, all the clients will have images from two classes randomly selected from all the classes.
From the above image, one can see that all the clients have images from two classes, although some outliers exist: a few clients have images from more than two classes, while some have far fewer images than the rest (client_20). Such outliers ensure that every client has a unique set of images, making this an extreme example of a non-IID dataset. We can also divide the CIFAR-10 dataset with classes per client = 1, where all the clients will have images from only one class, again with some outliers.
Let’s dive deep into the coding part to understand it in more detail.
1. Importing the libraries
2. Understanding the parameters
- classes_pc: classes per client; it is used to turn the balanced dataset into a non-IID dataset by creating an unbalanced representation of classes among the clients. E.g., if classes_pc=1, all the clients will have images from one class only, thus creating an extensive imbalance among the clients. (Ref: Figure 2)
- num_clients: Total number of clients among which images are to be distributed.
- batch_size: the batch size used when loading the data into the data loader.
- real_wd: We are creating two types of datasets, the real-world dataset (Figure 1) and the extreme non-IID dataset (Figure 2). If real_wd is TRUE, a dataset replicating real life is created, i.e. the real-world dataset (Figure 1). If real_wd is FALSE (the default), the extreme non-IID dataset is created.
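Taken together, these parameters form the signature of the entry-point helper used later in the article (a sketch; the exact defaults in the original may differ):

```python
def get_data_loader(classes_pc=2, num_clients=20, batch_size=32, real_wd=False):
    """Split CIFAR-10 among num_clients and return their data loaders.

    real_wd=True  -> real-world split (Figure 1)
    real_wd=False -> extreme non-IID split with classes_pc classes
                     per client (Figure 2)
    """
    ...  # body assembled from the helper functions introduced below
```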
3. Creating the distribution
The get_cifar10 function downloads the CIFAR-10 dataset and returns x_train, y_train for training and x_test, y_test for testing. Lines 5–6 download the dataset from torchvision, and lines 8–9 convert it into NumPy arrays.
The clients_rand function creates a random distribution for the clients, such that every client has an arbitrary number of images. It is one of the helper functions to be used in the upcoming code snippets.
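A sketch of such a helper, assuming it returns a list of per-client image counts (the original snippet is embedded on Medium and may differ in detail):

```python
import numpy as np

def clients_rand(train_len, n_clients):
    """Randomly partition train_len samples among n_clients so that every
    client gets an arbitrary, nonzero share. Returns a list of counts
    summing to train_len (a sketch of the helper described above)."""
    # draw random positive weights and scale them to the dataset size
    weights = np.random.randint(10, 100, n_clients)
    counts = (weights / weights.sum() * train_len).astype(int)
    counts[-1] += train_len - counts.sum()  # fix rounding so counts sum exactly
    return counts.tolist()
```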
The split_image_data_realwd function splits the given images among n_clients. It returns a split, which is further used to create the real-world dataset. The entire snippet is commented to explain what happens in sequential order. To understand more about this distribution, please head over to Part 2 of this series.
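The original snippet is embedded on Medium; a much-simplified sketch of the idea (shuffle the data, then carve it into random-sized shards) could look like this. The original helper additionally varies how many classes each client sees:

```python
import numpy as np

def split_image_data_realwd(data, labels, n_clients=10):
    """Split (data, labels) into n_clients random-sized shards
    (a much-simplified sketch of the real-world split; the original
    helper also varies the number of classes per client)."""
    # shuffle images and labels in unison
    order = np.random.permutation(len(data))
    data, labels = data[order], labels[order]
    # random per-client counts that sum to the dataset size
    weights = np.random.randint(10, 100, n_clients)
    counts = (weights / weights.sum() * len(data)).astype(int)
    counts[-1] += len(data) - counts.sum()
    # carve the shuffled data into consecutive shards
    splits, start = [], 0
    for c in counts:
        splits.append((data[start:start + c], labels[start:start + c]))
        start += c
    return splits
```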
Now, we will create a similar split for the non-IID dataset, as explained above (with a diagram). To understand more about this distribution, please head over to Part 2 of this series.
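A sketch of this split, assuming balanced class pools and equal-sized shards (the original helper also handles leftover images and shard-size balancing):

```python
import numpy as np

def split_image_data(data, labels, n_clients=20, classes_per_client=2):
    """Give each client shards drawn from classes_per_client randomly
    chosen classes (a simplified sketch; class pools can run dry, which
    produces the smaller-than-average outlier clients mentioned above)."""
    n_classes = len(np.unique(labels))
    # shuffled index pool for every class
    pools = {c: list(np.random.permutation(np.where(labels == c)[0]))
             for c in range(n_classes)}
    shard = len(data) // (n_clients * classes_per_client)  # images per shard
    splits = []
    for _ in range(n_clients):
        idx = []
        for c in np.random.choice(n_classes, classes_per_client, replace=False):
            # take one shard from this class's pool and consume it
            take, pools[c] = pools[c][:shard], pools[c][shard:]
            idx.extend(take)
        splits.append((data[idx], labels[idx]))
    return splits
```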
The shuffle_list function takes the output of the above functions (split_image_data_realwd or split_image_data) and shuffles each client's images.
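A sketch of this helper, assuming each split is a list of (images, labels) pairs as in the sketches above:

```python
import numpy as np

def shuffle_list(splits):
    """Shuffle each client's (images, labels) pair in unison so that the
    image/label alignment is preserved (a sketch of the helper above)."""
    shuffled = []
    for x, y in splits:
        order = np.random.permutation(len(x))
        shuffled.append((x[order], y[order]))
    return shuffled
```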
The below code snippet converts the split into a data loader (image augmentation is done in this part) so it can be given as input to the model for training.
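The original snippet is embedded on Medium; a simplified sketch of the conversion could look like the following (the function name is illustrative, and the original also applies augmentation such as random crops and horizontal flips via torchvision.transforms before batching):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def splits_to_loaders(splits, batch_size=32):
    """Wrap each client's NumPy split in a DataLoader (a simplified
    sketch; the original snippet also applies augmentation such as
    random crops and horizontal flips before batching)."""
    loaders = []
    for x, y in splits:
        # scale uint8 images to [0, 1]; augmentation transforms would go here
        tx = torch.tensor(x, dtype=torch.float32) / 255.0
        ty = torch.tensor(y, dtype=torch.long)
        loaders.append(DataLoader(TensorDataset(tx, ty),
                                  batch_size=batch_size, shuffle=True))
    return loaders
```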
The get_data_loader function uses the above helper functions and converts the CIFAR-10 dataset into the real-world or non-IID type, whichever is required.
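Putting it together: an end-to-end sketch of the orchestration, with the download and helper steps replaced by inlined, much-simplified stand-ins so the control flow runs on its own (names, shapes, and defaults are illustrative):

```python
import numpy as np

def get_data_loader(classes_pc=2, num_clients=20, batch_size=32, real_wd=False):
    """Orchestration sketch: fetch data, pick a split strategy, then
    shuffle each client's shard. The dataset and split steps here are
    simplified stand-ins for the helpers described above."""
    # stand-in for get_cifar10(): a small synthetic, balanced dataset
    x = np.zeros((1000, 3, 32, 32), dtype=np.uint8)
    y = np.repeat(np.arange(10), 100)
    if real_wd:
        # real-world split: random-sized shards over shuffled data
        cuts = np.sort(np.random.choice(np.arange(1, len(x)),
                                        num_clients - 1, replace=False))
        groups = np.split(np.random.permutation(len(x)), cuts)
    else:
        # extreme non-IID split: classes_pc classes per client
        pools = [list(np.where(y == c)[0]) for c in range(10)]
        shard = len(x) // (num_clients * classes_pc)
        groups = []
        for _ in range(num_clients):
            g = []
            for c in np.random.choice(10, classes_pc, replace=False):
                g.extend(pools[c][:shard])
                pools[c] = pools[c][shard:]
            groups.append(np.array(g, dtype=int))
    # shuffle each client's shard (stand-in for shuffle_list); the result
    # would then be wrapped in DataLoaders using batch_size
    splits = []
    for g in groups:
        g = np.random.permutation(g)
        splits.append((x[g], y[g]))
    return splits
```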
Now we are ready with all the functions needed to create a real-world or non-IID dataset, which can further be used in federated learning to develop state-of-the-art models.