Source: Deep Learning on Medium
Double Deep Q learning introduction
This is the second part of a reinforcement learning tutorial series. In the first tutorial we used a simple method to train a Deep Q Neural Network model to play the CartPole balancing game. In this part our task will be the same, but this time our agent will use two (Double) neural networks to train the main model. In addition, we’ll implement a ‘soft’ parameter update function, and finally we will compare the results we got with each method.
This post goes into the original structure of the DQN agent presented in this article. We will implement the additional enhancements described in the article and present the results of testing the implementation in the CartPole-v1 environment. It is assumed that the reader has a basic familiarity with reinforcement or machine learning and has finished reading the first tutorial.
The agent implemented in this tutorial follows the structure of the original DQN introduced in this paper, but is closer to what is known as a Double DQN: an enhanced version of the original DQN, introduced nine months later, that uses its internal neural networks slightly differently in its learning process. Technically, the agent implemented in this tutorial isn’t really a DQN, since our networks contain only a few layers (4) and so are not actually that “deep”, but the distinguishing ideas of Deep Q Learning are still used.
This tutorial’s code will also test an additional enhancement that changes how the agent’s neural networks work with each other. Known as “soft target network updates”, it was introduced in another Deep Q Learning paper.
As we already know from my previous tutorial, in the CartPole game the state is represented by 4 values (cart position, cart velocity, pole angle, and the velocity of the tip of the pole), and at every step the agent can take one of two actions: move left or move right.
Double Deep Q learning
In Double Deep Q Learning, the agent uses two neural networks to learn and predict what action to take at every step. One network, referred to as the Q network or the online network, is used to predict what to do when the agent encounters a new state. It takes in the state as input and outputs Q values for the possible actions that could be taken. In the agent described in this tutorial, the online network takes in a vector of four values (state of the CartPole environment) as input and outputs a vector of two Q values, one for the value of moving left in the current state, and one for the value of moving right in the current state. The agent will choose the action that has the higher corresponding Q value output by the online network.
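To make the input and output shapes concrete, here is a minimal plain-Python sketch. It replaces the neural network with a single linear layer whose WEIGHTS and BIASES are hypothetical, hand-picked values (not learned, and not from the tutorial’s code), but the interface is the same as described above: four state values in, two Q values out, and the agent picks the action with the higher Q value.

```python
from typing import List, Sequence

# Hypothetical, hand-picked weights for a single linear layer standing in for
# the online network: one row per action, one column per CartPole state value.
WEIGHTS = [
    [0.1, -0.2, -1.5, -0.8],  # action 0: move left
    [-0.1, 0.2, 1.5, 0.8],    # action 1: move right
]
BIASES = [0.0, 0.0]


def predict_q_values(state: Sequence[float]) -> List[float]:
    """Map the 4 state values to one Q value per action (2 outputs)."""
    return [
        sum(w * s for w, s in zip(row, state)) + b
        for row, b in zip(WEIGHTS, BIASES)
    ]


def greedy_action(state: Sequence[float]) -> int:
    """Choose the action with the higher predicted Q value."""
    q_values = predict_q_values(state)
    return max(range(len(q_values)), key=lambda a: q_values[a])


# State: cart position, cart velocity, pole angle, pole tip velocity.
state = [0.0, 0.1, 0.05, 0.2]
print(predict_q_values(state), greedy_action(state))
```

With the pole leaning right (positive angle and tip velocity), these made-up weights give the higher Q value to action 1, so the greedy agent moves right.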
Double DQN handles the problem of Q-value overestimation. When we calculate the target with a single deep network, we face a simple problem: how can we be sure that the best action for the next state is the action with the highest Q-value?
The accuracy of the Q values depends on which actions we have tried and which neighboring states we have explored.
Consequently, at the beginning of training we don’t have enough information about the best action to take. Taking the maximum Q value (which is noisy) as the best action can therefore lead to false positives. Because the maximum of several noisy estimates tends to be larger than the true maximum, these errors push the targets upward instead of averaging out. If non-optimal actions are regularly given a higher Q value than the optimal action, learning will be complicated.
The solution is: when we compute the Q target, we use two networks to decouple the action selection from the target Q value generation. We:
- use our online (DQN) network to select the best action to take for the next state (the action with the highest Q value).
- use our target network to calculate the target Q value of taking that action at the next state.
Therefore, Double DQN helps us reduce the overestimation of Q values and helps us train faster and have more stable learning.
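The overestimation effect and the benefit of decoupling can be checked with a small simulation. This is a sketch under idealized assumptions, not part of the tutorial code: every action has the same true value of 0, each “network” is replaced by an independent noisy estimate, and the noise level is made up. Taking the max of one noisy estimate is biased upward, while letting estimate A select the action and estimate B evaluate it removes that upward bias in this setup.

```python
import random
import statistics

random.seed(0)

N_ACTIONS = 2      # CartPole has two actions
TRUE_Q = 0.0       # assumption: every action is truly worth the same
NOISE_STD = 1.0    # hypothetical estimation noise
TRIALS = 10_000

single_targets = []  # one estimate both selects and evaluates (standard DQN style)
double_targets = []  # estimate A selects, estimate B evaluates (Double DQN style)

for _ in range(TRIALS):
    # Two independent noisy estimates of the same true Q values, standing in
    # for the online network (A) and the target network (B).
    est_a = [TRUE_Q + random.gauss(0.0, NOISE_STD) for _ in range(N_ACTIONS)]
    est_b = [TRUE_Q + random.gauss(0.0, NOISE_STD) for _ in range(N_ACTIONS)]

    # Single estimator: the same noisy values pick the action and score it.
    single_targets.append(max(est_a))

    # Double estimator: A selects the action, B evaluates it.
    best = max(range(N_ACTIONS), key=lambda a: est_a[a])
    double_targets.append(est_b[best])

single_bias = statistics.mean(single_targets) - TRUE_Q
double_bias = statistics.mean(double_targets) - TRUE_Q
print(f"single-estimator bias: {single_bias:.3f}")
print(f"double-estimator bias: {double_bias:.3f}")
```

Real online and target networks are strongly correlated (the target network follows the online one), so in practice the reduction is smaller than in this idealized setup.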
To understand Double DQN, let’s analyze a practical example that follows the agent’s path through different states. The process of Double Q-Learning can be visualized in the following graph:
The agent starts in state s. Based on some previous calculations, it knows the qualities Q(s, a1) and Q(s, a2) of the two possible actions in that state. The agent decides to take action a1 and ends up in state s’.
The Q-Network calculates the qualities Q(s’, a1′) and Q(s’, a2′) of the possible actions in this new state. Action a1′ is picked because it has the highest quality according to the Q-Network.
The new action-value Q(s, a1) for action a1 in state s can now be calculated with the equation in the figure above, where Q(s’, a1′) is the evaluation of a1′ determined by the Target-Network.
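The update described by the figure can be written as follows. This is a reconstruction assuming the standard Double Q-learning target, chosen to match the Bellman equation used later in this tutorial (γ is discount_factor and r is the reward for taking a1 in s). The last two lines are a worked example with made-up numbers: r = 1, γ = 0.95, Q-Network values of 2.0 and 1.5 for a1′ and a2′, and a Target-Network value of 1.8 for a1′.

```latex
\begin{aligned}
a^{*} &= \arg\max_{a'} Q_{\text{online}}(s', a') \\
Q(s, a_1) &\leftarrow r + \gamma \, Q_{\text{target}}(s', a^{*}) \\[6pt]
a^{*} &= a_1' \quad \text{because } Q_{\text{online}}(s', a_1') = 2.0 > Q_{\text{online}}(s', a_2') = 1.5 \\
Q(s, a_1) &\leftarrow 1 + 0.95 \times Q_{\text{target}}(s', a_1') = 1 + 0.95 \times 1.8 = 2.71
\end{aligned}
```

For comparison, if the Target-Network valued a2′ at 2.2, a standard DQN target would use that maximum instead and give 1 + 0.95 × 2.2 = 3.09, a larger and possibly overestimated value.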
All DQN agents learn and improve themselves through a method called experience replay, and this is where the second network of a Double DQN, called the target network, comes into play. To carry out experience replay, the agent “remembers” each step of the game after it happens. Each memory consists of the action taken, the state the agent was in when the action was taken, the reward received for taking the action, and the state that resulted from taking the action. These memories are also known as experiences. After each step of the episode (once enough memories have been stored), the experience replay procedure is carried out. It consists of the following steps:
- a random set of experiences (called a minibatch) is retrieved from the agent’s memory
- for each experience in the minibatch, a new Q value is calculated for its state and action pair. If the action ended the episode, the new Q value is simply the reward, which is negative (bad). If the action did not end the game (i.e. it kept the agent alive for at least one more turn), the Q value is estimated with what is called a Bellman equation. The general form of the Bellman equation used by the agent implemented here is:
value = reward + discount_factor * target_network.predict(next_state)[argmax(online_network.predict(next_state))]
- the online network is fit to associate each state in the minibatch with the new Q values calculated for the actions taken
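The memory and minibatch steps above can be sketched in plain Python. This is not the tutorial’s own code: the class and field names, the capacity of 2000 and the batch size of 64 are assumptions, and each experience also stores a done flag (whether the action ended the episode), which the Q value calculation above needs.

```python
import random
from collections import deque
from typing import Deque, List, NamedTuple, Sequence


class Experience(NamedTuple):
    """One remembered step of the game (field names are hypothetical)."""
    state: Sequence[float]
    action: int
    reward: float
    next_state: Sequence[float]
    done: bool  # assumption: also store whether this action ended the episode


class ReplayMemory:
    """Fixed-size memory: once full, the oldest experiences are dropped."""

    def __init__(self, capacity: int) -> None:
        self.buffer: Deque[Experience] = deque(maxlen=capacity)

    def remember(self, experience: Experience) -> None:
        self.buffer.append(experience)

    def __len__(self) -> int:
        return len(self.buffer)

    def sample(self, batch_size: int) -> List[Experience]:
        """Retrieve a random minibatch of distinct experiences."""
        return random.sample(list(self.buffer), batch_size)


memory = ReplayMemory(capacity=2000)  # hypothetical capacity
BATCH_SIZE = 64                       # hypothetical minibatch size

# Fake 100-step episode, just to fill the memory.
for step in range(100):
    s = [0.0, 0.0, 0.0, 0.0]
    memory.remember(Experience(s, step % 2, 1.0, s, done=(step == 99)))

# Replay only starts once enough memories have been stored.
if len(memory) >= BATCH_SIZE:
    minibatch = memory.sample(BATCH_SIZE)
    print(len(minibatch))
```

Using a deque with maxlen discards the oldest memories automatically once the memory is full, and sampling at random breaks up the correlation between consecutive steps of the same episode.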
Below is the experience replay method used in the code, with short explanations:
Why are two networks needed to generate the new Q value for each action? The single online network could be used to generate the Q values to update, but then each update would consist of the online network updating its weights to better predict what it itself outputs. The agent would be trying to fit a target value that it defines itself, which can make the network update itself too drastically and in an unproductive way. To avoid this, the Q values used for the update are taken from the output of the second network, the target network, which is meant to reflect the state of the online network without holding identical values.
In the code above you can see the line if self.ddqn:. The code is written so that changing this one variable to False switches the agent to a standard DQN, which helps us compare the results of the two models.
What makes this network a Double DQN?
The Bellman equation used to calculate the Q values to update the online network follows the equation:
value = reward + discount_factor * target_network.predict(next_state)[argmax(online_network.predict(next_state))]
The Bellman equation used to calculate the Q value updates in the original DQN is:
value = reward + discount_factor * max(target_network.predict(next_state))
The difference is that, using the terminology of the field, the second equation uses the target network for both SELECTING and EVALUATING the action to take whereas the first equation uses the online network for SELECTING the action to take and the target network for EVALUATING the action. Selection here means choosing which action to take, and evaluation means getting the projected Q value for that action. This form of the Bellman equation is what makes this agent a Double DQN and not just a DQN.
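Because the two Bellman equations differ in a single line, one flag is enough to switch between them. The following plain-Python sketch computes the targets for a minibatch; it is not the tutorial’s code. The function name is hypothetical, the nested lists stand in for the outputs of online_network.predict and target_network.predict, and the discount factor of 0.95 and terminal reward of -100 are assumptions.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])


def compute_targets(
    q_current: Sequence[Sequence[float]],      # online network on each state
    actions: Sequence[int],
    rewards: Sequence[float],
    dones: Sequence[bool],
    q_next_online: Sequence[Sequence[float]],  # online network on each next_state
    q_next_target: Sequence[Sequence[float]],  # target network on each next_state
    discount_factor: float = 0.95,             # hypothetical value
    ddqn: bool = True,
) -> List[List[float]]:
    """Training targets: only the Q value of the action actually taken changes."""
    targets = [list(row) for row in q_current]  # copy; leave the input untouched
    for i, action in enumerate(actions):
        if dones[i]:
            # The action ended the episode: no next state to look ahead from.
            targets[i][action] = rewards[i]
        elif ddqn:
            # Double DQN: online network SELECTS, target network EVALUATES.
            best = argmax(q_next_online[i])
            targets[i][action] = rewards[i] + discount_factor * q_next_target[i][best]
        else:
            # Standard DQN: target network both SELECTS and EVALUATES.
            targets[i][action] = rewards[i] + discount_factor * max(q_next_target[i])
    return targets


# Hypothetical minibatch of two experiences; the second one ended the episode.
q_current = [[0.5, 0.7], [0.2, 0.1]]
actions = [0, 1]
rewards = [1.0, -100.0]  # assumed penalty when the pole falls
dones = [False, True]
q_next_online = [[1.0, 3.0], [0.0, 0.0]]
q_next_target = [[2.5, 1.5], [0.0, 0.0]]

ddqn_targets = compute_targets(q_current, actions, rewards, dones,
                               q_next_online, q_next_target, ddqn=True)
dqn_targets = compute_targets(q_current, actions, rewards, dones,
                              q_next_online, q_next_target, ddqn=False)
print(ddqn_targets[0][0], dqn_targets[0][0])
```

Here the online network prefers action 1 in the next state, so the Double DQN target uses the target network’s value for action 1 (1 + 0.95 × 1.5 = 2.425), while the standard DQN target takes the target network’s own maximum (1 + 0.95 × 2.5 = 3.375). Copying q_current and changing only the taken action means the network receives no error signal for actions that were not taken.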
Soft Target Network Update
In the original DQN paper, the target network’s weights are updated by setting them equal to the online network’s weights every fixed number of steps. Another established method, introduced here, updates the target network weights incrementally: after every run of experience replay, the target network weights are moved toward the online network weights with the following formula:
target_weights = target_weights * (1-TAU) + q_weights * TAU where 0 < TAU < 1
If you don’t fully understand the code here, don’t worry; it’s much easier to understand once you get deeper into the full code. Simply put, if we use TAU = 0.1, the formula becomes target_weights = target_weights * 0.9 + q_weights * 0.1.
This means that each update takes only 10% from the new (online) weights and keeps 90% of the old target weights.
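The soft update formula is applied element by element, as in the following plain-Python sketch. It is not the tutorial’s code: the weights are represented as nested lists (one list per layer) and the function names are hypothetical. With Keras models, the same arithmetic would be applied to the arrays returned by get_weights() and written back with set_weights(). A hard update, as in the original DQN, is included for comparison.

```python
from typing import List

Weights = List[List[float]]  # simplified: one flat list of floats per layer


def soft_update(target_weights: Weights, q_weights: Weights, tau: float) -> Weights:
    """target_weights * (1 - tau) + q_weights * tau, element by element."""
    if not 0.0 < tau < 1.0:
        raise ValueError("tau must satisfy 0 < tau < 1")
    return [
        [t * (1.0 - tau) + q * tau for t, q in zip(t_layer, q_layer)]
        for t_layer, q_layer in zip(target_weights, q_weights)
    ]


def hard_update(q_weights: Weights) -> Weights:
    """Original DQN style: copy the online weights outright."""
    return [list(layer) for layer in q_weights]


target = [[1.0, 2.0], [0.5]]
online = [[0.0, 4.0], [1.5]]
once = soft_update(target, online, tau=0.1)  # roughly [[0.9, 2.2], [0.6]]

# Repeated soft updates move the target steadily toward fixed online weights.
drifted = target
for _ in range(200):
    drifted = soft_update(drifted, online, tau=0.1)
```

After 200 updates with TAU = 0.1 the target weights are well within 1e-6 of the (here fixed) online weights; a smaller TAU makes the target network follow more slowly and therefore more smoothly.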
By running experience replay every time the agent takes an action and updating the parameters of the online network, the online network begins to associate state/action pairs with appropriate Q values: the more promising an action is in a given state, the higher the Q value the model learns to predict for it, and the longer the agent survives as it keeps playing the game.
In the full code implementation we also wrote a few simple functions to track our scores and plot the results in a graph for easier visual comparison. Here are those functions:
The full code is at the end of this tutorial. With this code we ran 3 different experiments:
1. self.Soft_Update = False and self.ddqn = False
2. self.Soft_Update = False and self.ddqn = True
3. self.Soft_Update = True and self.ddqn = True
To keep the experiments from running too long, we limited training to 1000 episodes: self.EPISODES = 1000
The first test used self.Soft_Update = False and self.ddqn = False, so with these parameters we were using a standard DQN without soft updates. The following graph shows that learning wasn’t stable and we kept getting random spikes. If we tried this model in test mode, it would perform the same way (consistently low results with occasional spikes), but our goal is stable solving.
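Per-episode scores like these are very noisy, so curves are easier to compare after smoothing. A minimal sketch of one common option, a moving average over recent episodes, is below; the helper name and the default window of 50 episodes are assumptions, not the tutorial’s score-tracking functions.

```python
from collections import deque
from typing import Deque, List, Sequence


def moving_average(scores: Sequence[float], window: int = 50) -> List[float]:
    """Mean of the last `window` scores at each episode (fewer at the start)."""
    recent: Deque[float] = deque(maxlen=window)
    averages: List[float] = []
    for score in scores:
        recent.append(score)
        averages.append(sum(recent) / len(recent))
    return averages


# The spike at episode 2 is softened instead of dominating the curve.
print(moving_average([10, 500, 12, 11], window=2))  # [10.0, 255.0, 256.0, 11.5]
```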