Federated Learning Demo in Python: Training Models using Federated Learning (Part 3)

Original article was published by Ahmed Gad on Artificial Intelligence on Medium


Receiving the Population at the Client

Using the pickle library, the client receives and decodes the GANN instance, as it did previously with the text messages. The code that does that is provided below.

As long as there’s data sent to the client, the client receives this data using the recv() method. Once all data is received, the bytes data is decoded using the pickle.loads() method.

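received_data = b''
# Keep reading until the buffer ends with '.', the byte that terminates a pickle stream.
while str(received_data)[-2] != '.':
    data = soc.recv(1024)
    received_data += data

# Decode the received bytes back into a Python object.
received_data = pickle.loads(received_data)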
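
The loop above relies on the last received byte being the `.` that ends a pickle stream, which can misfire if a 1024-byte chunk happens to end with that byte in the middle of a message. A common alternative, which is not what the original project does, is to prefix every message with its length. The sketch below shows the idea over a local socket pair; the helper names `send_msg()`, `recv_exact()`, and `recv_msg()` are hypothetical and not part of the tutorial's code.

```python
import pickle
import socket
import struct

HEADER = struct.Struct("!I")  # 4-byte unsigned big-endian length prefix

def send_msg(sock, obj):
    """Pickle obj and send it preceded by its length in bytes."""
    payload = pickle.dumps(obj)
    sock.sendall(HEADER.pack(len(payload)) + payload)

def recv_exact(sock, n):
    """Read exactly n bytes, or raise if the peer closes early."""
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(min(remaining, 1024))
        if not chunk:
            raise ConnectionError("socket closed before the full message arrived")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def recv_msg(sock):
    """Read one length-prefixed message and unpickle it."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return pickle.loads(recv_exact(sock, length))

# Demo over a connected pair of local sockets (stand-ins for server and client).
server_end, client_end = socket.socketpair()
send_msg(server_end, {"subject": "model", "data": [0.1, 0.2, 0.3]})
message = recv_msg(client_end)
print(message["subject"])
server_end.close()
client_end.close()
```

With a length prefix, the receiver knows exactly how many bytes to wait for, so it never has to inspect the payload to decide whether the message is complete.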

Remember that the server sent a dictionary to the client. As a result, the received_data variable holds a dictionary. Its content is printed below.

print(received_data)
{'subject': 'model', 'data': <pygad.gann.gann.GANN at 0x23de2f22208>}

The subject key can have its associated value set to either model or done. If the dictionary item with the key subject is set to model, this means the dictionary has an item with key data, which holds the model. For done, this means the model is trained.

subject = received_data["subject"]
if subject == "model":
    GANN_instance = received_data["data"]
    print(GANN_instance)
elif subject == "done":
    print("Model is trained.")
else:
    print("Unrecognized message type.")

<pygad.gann.gann.GANN object at 0x0000023DE2F22208>

The networks in the population can be returned using the population_networks attribute.

print(GANN_instance.population_networks)

The pygad.gann.population_as_vectors() function can also be used to return the parameters of all networks in the population.

population_vectors = pygad.gann.population_as_vectors(population_networks=GANN_instance.population_networks)

After receiving the population of neural networks at the client, the next section discusses using the genetic algorithm to evolve the solutions’ parameters.

Training the Network using the Genetic Algorithm

To optimize a problem using PyGAD (e.g. to train a neural network), an instance of the pygad.GA class is needed. The constructor of this class accepts many parameters, so I’d recommend checking out the documentation of this class.

To create an instance of the pygad.GA class, the following code snippet uses a function named prepare_GA(). It accepts the GANN instance received from the server and returns an instance of the pygad.GA class.

The initial_population parameter accepts the parameters of all networks in the population, which are returned using the pygad.gann.population_as_vectors() function. This is how the pygad.GA instance is linked to the pygad.gann.GANN instance.

There are 2 important parameters to be prepared: fitness_func and callback_generation.

The fitness_func parameter accepts a function that calculates the fitness value for each solution. In this example, it calculates the classification accuracy of each solution in the population. The passed function to the fitness_func parameter must accept 2 parameters. The first one represents the solution to calculate its fitness, and the second one is the index of that solution in the population.
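
To make the two-parameter signature concrete, here is a minimal sketch of an accuracy-based fitness function. It doesn't use PyGAD at all: the `population_predictions` list is a hypothetical stand-in for the predictions that the networks would produce, and the labels are the first client's two XOR outputs listed later in this tutorial. The first argument is unused in this sketch because the predictions are looked up by index. The signature may also differ between PyGAD versions, so check the documentation of the release you have installed.

```python
# Labels of the two XOR samples held by this client.
data_outputs = [1, 0]

# Hypothetical stand-in for the networks' predictions: one list of predicted
# labels per solution. In the real project these come from the PyGAD networks.
population_predictions = [
    [1, 0],  # solution 0: both samples correct
    [1, 1],  # solution 1: one sample correct
    [0, 1],  # solution 2: no sample correct
]

def fitness_func(solution, sol_idx):
    """Return the classification accuracy (0 to 100) of the solution at sol_idx."""
    predictions = population_predictions[sol_idx]
    correct = sum(1 for p, t in zip(predictions, data_outputs) if p == t)
    return 100.0 * correct / len(data_outputs)

fitness_values = [fitness_func(None, idx) for idx in range(len(population_predictions))]
print(fitness_values)  # [100.0, 50.0, 0.0]
```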

The callback_generation parameter accepts a function that’s called after each generation. In this example, it’s used to update the parameters of all networks used in the population, using the update_population_trained_weights() method.

Note that the client only has 2 samples of the XOR data. The other 2 samples are available in another client.

The next code snippet calls the prepare_GA() function to create the pygad.GA instance. After being created, the run() method can be called to run the genetic algorithm. The plot_result() method shows how fitness values change by generation. The best_solution() method returns information about the best solution in the population.

ga_instance = prepare_GA(GANN_instance)

ga_instance.run()

ga_instance.plot_result()

best_sol_idx = ga_instance.best_solution()[2]

Sending the Client’s Trained Network Back to the Server

After the model is trained at the client, the client then sends the model back to the server, along with some more information within a dictionary. There are 3 keys in the dictionary:

  1. subject: If set to model, then the client is sending the model to the server. If set to echo, then the server will forward the client’s message back to it.
  2. data: The message data. If the subject is set to model, then it holds the instance of the pygad.gann.GANN class after updating the parameters of the solutions (i.e. networks).
  3. best_solution_idx: The index of the best solution in the population.

data = {"subject": subject, "data": GANN_instance, "best_solution_idx": best_sol_idx}
data_byte = pickle.dumps(data)

print("Sending the Model to the Server.\n")
soc.sendall(data_byte)

Final Client Code

The final complete code of the client is listed below. The code is available at GitHub in the TutorialProject/Part3/client1.py and TutorialProject/Part3/client2.py files under the Federated Learning GitHub project.

The behavior of the client is as follows:

  • Send a message to the server.
  • Receive the model from the server.
  • Run the genetic algorithm to optimize the network parameters.
  • Keep working until a message arrives in which the subject key is set to done.
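
The behavior listed above can be sketched as a small loop. This is only an illustration, not the code in client1.py: `run_client()` and its `send`, `receive`, and `train` arguments are hypothetical callables standing in for the socket I/O and the genetic algorithm step, and the server's replies are scripted so that the loop can run on its own.

```python
def run_client(send, receive, train):
    """Drive one client until the server reports that training is done.

    Returns the number of training rounds that were performed.
    """
    send({"subject": "echo", "data": "Hello from the client"})
    rounds = 0
    while True:
        message = receive()
        subject = message["subject"]
        if subject == "done":
            return rounds
        if subject == "model":
            trained_model, best_idx = train(message["data"])
            send({"subject": "model", "data": trained_model,
                  "best_solution_idx": best_idx})
            rounds += 1
        # Any other subject is ignored in this sketch.

# Scripted server replies (an assumption made only for this demo).
scripted_replies = iter([
    {"subject": "model", "data": [0.0, 0.0]},
    {"subject": "model", "data": [0.5, 0.5]},
    {"subject": "done", "data": None},
])
sent = []

def fake_train(model):
    # Stand-in for running the genetic algorithm: nudge every parameter.
    return [w + 1.0 for w in model], 0

rounds = run_client(sent.append, lambda: next(scripted_replies), fake_train)
print(rounds, [m["subject"] for m in sent])
```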

Aggregating Models’ Parameters at the Server

Once a client sends its locally-trained version of the model, then the server uses its parameters and aggregates them with the current parameters, returning a new version of the model. This new version is then sent to the clients, and the process is repeated.

At the server, the model is received and handled by the reply() method. This method accepts the model sent by the client, aggregates its individual parameters with the current global parameters, and replies to the client.

The reply() method is provided below. The method uses 4 global variables:

  1. GANN_instance: The model received from the client.
  2. data_inputs: The test data inputs.
  3. data_outputs: The test data outputs.
  4. model: The most recent model updated after each client sends its trained model. It’s initialized to None.

The method checks the subject key in the dictionary. If the subject is echo, then it just replies with the same message received from the client. When it’s model, then it extracts the client’s trained model, aggregates its parameters with the current model’s parameters using the model_averaging() method, and calculates the error of the updated model.

If the error of the model is 0 (i.e. classification accuracy is 100%), then it replies with a dictionary in which the subject key is set to done to inform the client that training has ended. Otherwise, it sets the subject to model and the data key to the current model.
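
The branching described above can be sketched as a plain function. This is an assumption-laden illustration rather than the project's reply() method: the global variables are replaced by a `state` dictionary, `average` and `evaluate_error` are hypothetical callables standing in for model_averaging() and the test-set evaluation, and a model is reduced to a single number so the example can run by itself. Setting the data key to None in the done message is also an assumption.

```python
def reply(message, state, average, evaluate_error):
    """Return the server's response to one client message.

    state["model"] holds the current global model (None until the first
    client model arrives).
    """
    subject = message["subject"]
    if subject == "echo":
        return message  # send the same message back to the client
    if subject == "model":
        client_model = message["data"]
        if state["model"] is None:
            # First model received: there is nothing to average with yet.
            state["model"] = client_model
        else:
            state["model"] = average(state["model"], client_model)
        if evaluate_error(state["model"]) == 0:
            return {"subject": "done", "data": None}
        return {"subject": "model", "data": state["model"]}
    raise ValueError("Unrecognized message type.")

# Toy demo: a "model" is one number and the error is its distance from 1.0.
state = {"model": None}
average = lambda a, b: (a + b) / 2
evaluate_error = lambda m: abs(m - 1.0)

r1 = reply({"subject": "echo", "data": "hi"}, state, average, evaluate_error)
r2 = reply({"subject": "model", "data": 0.0}, state, average, evaluate_error)
r3 = reply({"subject": "model", "data": 2.0}, state, average, evaluate_error)
print(r1["subject"], r2["subject"], r3["subject"])  # echo model done
```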

The reply() method is called from the thread’s run() method.

def run(self):
    ...
    while True:
        ...
        self.reply(received_data)

The implementation of the model_averaging() method is shown below. It accepts the client’s model in addition to the current model available at the server, and returns a new model with its parameters set to the average of the parameters of the 2 models.
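
As a rough illustration of the averaging step, the sketch below averages two sets of weights element by element. It's a standalone function working on nested Python lists, not the project's method: representing a model as a list of layers, where each layer is a list of rows, is an assumption made to keep the example free of dependencies.

```python
def model_averaging(current_weights, client_weights):
    """Average two models layer by layer, element by element.

    Each model is a list of layers; each layer is a list of rows of floats.
    Both models are assumed to have exactly the same shape.
    """
    averaged = []
    for cur_layer, cli_layer in zip(current_weights, client_weights):
        averaged.append([
            [(c + k) / 2 for c, k in zip(cur_row, cli_row)]
            for cur_row, cli_row in zip(cur_layer, cli_layer)
        ])
    return averaged

# Two tiny models: a 2x2 layer followed by a 2x1 layer.
current = [[[0.0, 1.0], [2.0, 3.0]], [[1.0], [1.0]]]
client = [[[1.0, 1.0], [0.0, 5.0]], [[3.0], [0.0]]]
new_model = model_averaging(current, client)
print(new_model)  # [[[0.5, 1.0], [1.0, 4.0]], [[2.0], [0.5]]]
```

Averaging gives both models equal weight. Because the server averages each incoming client model into its running model, the most recently received model always contributes half of the result.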

Final Server Code

The complete code of the server is listed below (TutorialProject/Part3/server.py file under the Federated Learning GitHub project).

The behavior of the server is as follows:

  • Create a connection with a client.
  • Receive data from the client.
  • Send the model to the client.
  • Keep sending & receiving the model from the client.
  • After reaching the desired accuracy/error, send a message in which the subject key is set to done.

Running the Server and the Clients

After completing the implementation of the server and the client, the next step is to run both the server and the clients. It’s expected that there will be multiple clients, where each client has some training samples.

Simply run the server, and then run the clients.

The server code is available here.

The code of the first client is here. This client has the 2 samples given below:

data_inputs = numpy.array([[0, 1],
                           [0, 0]])

data_outputs = numpy.array([1,
                            0])

The code of the second client is here. This client has the 2 samples given below:

data_inputs = numpy.array([[1, 0],
                           [1, 1]])

data_outputs = numpy.array([1,
                            0])

Note that you can close a client’s connection, or create new clients and connect them to the server, at any time.

In the server, there are print statements that help us figure out the state of the server. After running the server and 2 clients, here are the outputs of the server’s print statements.

When the prediction error at the server is 0.0, the server replies to the clients with messages in which the subject key is set to done.

Socket Created.

Socket Bound to IPv4 Address & Port Number.

Socket is Listening for Connections ....

New Connection from ('127.0.0.1', 51487).

Running a Thread for the Connection with ('127.0.0.1', 51487).

Waiting to Receive Data Starting from 26/6/2020 21:25:29 GMT
All data (74 bytes) Received from ('127.0.0.1', 51487).
Client's Message Subject is echo.
Replying to the Client.

Waiting to Receive Data Starting from 26/6/2020 21:25:29 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [1. 1. 1. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:30 GMT

New Connection from ('127.0.0.1', 51490).

Running a Thread for the Connection with ('127.0.0.1', 51490).

Waiting to Receive Data Starting from 26/6/2020 21:25:31 GMT
All data (74 bytes) Received from ('127.0.0.1', 51490).

Waiting to Receive Data Starting from 26/6/2020 21:25:31 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [1. 1. 1. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:32 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).
Model Predictions: [0. 1. 0. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:32 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [0. 0. 0. 0.]
Error = 2.0

Waiting to Receive Data Starting from 26/6/2020 21:25:33 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).
Model Predictions: [0. 1. 0. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:34 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [0. 1. 0. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:35 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).
Model Predictions: [1. 1. 0. 0.]
Error = 2.0

Waiting to Receive Data Starting from 26/6/2020 21:25:35 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [1. 1. 1. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:36 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).
Model Predictions: [1. 1. 0. 0.]
Error = 2.0

Waiting to Receive Data Starting from 26/6/2020 21:25:37 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [0. 0. 0. 0.]
Error = 2.0

Waiting to Receive Data Starting from 26/6/2020 21:25:38 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).
Model Predictions: [0. 0. 1. 0.]
Error = 1.0

Waiting to Receive Data Starting from 26/6/2020 21:25:39 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51487).
Model Predictions: [0. 1. 1. 0.]
Error = 0.0

Waiting to Receive Data Starting from 26/6/2020 21:25:40 GMT
All data (2836 bytes) Received from ('127.0.0.1', 51490).

Waiting to Receive Data Starting from 26/6/2020 21:25:40 GMT

Connection Closed with ('127.0.0.1', 51487) either due to inactivity for 10 seconds or due to an error.

Connection Closed with ('127.0.0.1', 51490) either due to inactivity for 10 seconds or due to an error.
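
The error values in the log are consistent with counting how many of the 4 predictions differ from the XOR targets. The sketch below checks that reading against the logged lines. The target order `[0, 1, 1, 0]` is inferred from the final prediction that produced an error of 0.0, and the `prediction_error()` helper is hypothetical, so treat this as a plausible reconstruction rather than the server's actual code.

```python
targets = [0, 1, 1, 0]  # assumed XOR targets of the server's four test samples

def prediction_error(predictions, targets):
    """Sum of absolute differences between predicted and target labels."""
    return float(sum(abs(p - t) for p, t in zip(predictions, targets)))

# (predictions, error) pairs copied from the server log above.
logged = [
    ([1, 1, 1, 0], 1.0),
    ([0, 1, 0, 0], 1.0),
    ([0, 0, 0, 0], 2.0),
    ([1, 1, 0, 0], 2.0),
    ([0, 0, 1, 0], 1.0),
    ([0, 1, 1, 0], 0.0),
]

for predictions, logged_error in logged:
    assert prediction_error(predictions, targets) == logged_error

print("All logged errors match.")
```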

Conclusion

In part 3 of our federated learning demo project in Python, the client-server socket application was extended to implement the concepts of federated learning. The code is available at GitHub.

The achievements of this tutorial are as follows:

  • The server creates a generic, non-trained model using PyGAD.
  • Once a client gets connected to the server, the server responds with the most recent version of the model.
  • The client trains the model on its local data using the genetic algorithm implemented in PyGAD.
  • The client sends the trained model to the server.
  • The server aggregates the parameters of its current model and the client’s model.
  • The server tests the model using its test data and calculates the accuracy/error.
  • The process of sending and receiving the model between the server and the client continues until the desired accuracy/error is reached.

In the next tutorial, we’ll use the Kivy framework to implement a GUI for both the server and the client. Using the python-for-android and Buildozer projects, we’ll then build mobile apps for both Android and iOS, which will ultimately serve as the clients in this federated learning pipeline.