Source: Deep Learning on Medium
Introduction to Prioritized Experience Replay
So, in our previous tutorial we implemented a Double Dueling DQN model, and we saw that it improved our agent slightly. Now it’s time to implement Prioritized Experience Replay (PER), which was introduced in 2015 by Tom Schaul et al. The idea of the paper is that some experiences may be more important than others for our training, but might occur less frequently.
Because we sample the batch uniformly (selecting the experiences randomly), these rich but rare experiences have very little chance of being selected.
That’s why, with PER, we will try to change the sampling distribution by using a criterion to define the priority of each tuple of experience.
So, we want to prioritize experiences where there is a big difference between our prediction and the TD target, since that means we still have a lot to learn from them.
As the priority, we’ll use the magnitude (absolute value) of our TD error plus a small constant:
$$p_t = |\delta_t| + e$$
– |δt|: magnitude of our TD error
– e: a small constant that ensures no experience has zero probability of being taken
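As a small illustration of this formula, the sketch below turns a batch of TD errors into priorities with NumPy. The function name `td_priorities` and the value e = 0.01 are assumptions chosen for the example.

```python
import numpy as np

def td_priorities(td_errors, e=0.01):
    """Priority p = |delta| + e for each TD error (e keeps every priority above zero)."""
    return np.abs(np.asarray(td_errors, dtype=np.float64)) + e

td_errors = np.array([0.5, -2.0, 0.0])
print(td_priorities(td_errors))  # the sample with zero TD error still gets priority e
```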
Then we store this priority alongside each experience in the replay buffer.
But we can’t use greedy prioritization, because it would lead to always training on the same experiences (those with high priority), and then we would over-fit our agent. So, we will use stochastic prioritization, which generates the probability of each experience being chosen for a replay:
$$P(i) = \frac{p_i^a}{\sum_k p_k^a}$$
– pi: priority value
– ∑k pk^a: normalization by all priority values in the replay buffer
– a: hyperparameter used to reintroduce some randomness in the experience selection (with a = 0 we get pure uniform random sampling, with a = 1 experiences are sampled in direct proportion to their priorities)
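To see how a moves between uniform and proportional sampling, here is a minimal NumPy sketch of P(i). The function name `sampling_probabilities` and the default a = 0.6 are assumptions picked for the example, not values taken from this article.

```python
import numpy as np

def sampling_probabilities(priorities, a=0.6):
    """P(i) = p_i**a / sum_k p_k**a (stochastic prioritization)."""
    scaled = np.asarray(priorities, dtype=np.float64) ** a
    return scaled / scaled.sum()

priorities = np.array([1.0, 2.0, 7.0])
print(sampling_probabilities(priorities, a=0.0))  # a = 0: uniform, each experience gets 1/3
print(sampling_probabilities(priorities, a=1.0))  # a = 1: proportional, 0.1, 0.2 and 0.7

# Draw a small batch of indexes according to these probabilities.
rng = np.random.default_rng(0)
batch = rng.choice(len(priorities), size=5, p=sampling_probabilities(priorities))
```

Note that even with a = 1 the highest-priority experience is not always chosen; it is only chosen more often.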
As a consequence, at each time step we get a batch of samples drawn according to this probability distribution, and we train our network on it. But we still have a problem here. With a normal Experience Replay (deque) buffer, we use a stochastic update rule, and it only gives correct (unbiased) estimates if the way we sample the experiences matches the underlying distribution they came from.
With normal (deque) experience replay, we select our experiences from a uniform distribution; simply put, we pick them at random. There is no bias, because each experience has the same chance of being taken, so we can update our weights normally. But because we use priority sampling, we abandon purely random sampling, and we therefore introduce a bias toward high-priority samples (they have a higher chance of being selected).
If we updated our weights normally, we would risk over-fitting our agent. Samples with high priority are likely to be used for training many more times than low-priority experiences (this is the bias). As a consequence, we would update our weights with only a small portion of experiences that we consider really interesting.
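To correct this bias, we’ll use importance sampling (IS) weights that adjust the update by reducing the weights of the often-seen samples:
$$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^b$$
– N: replay buffer size
– P(i): sampling probability
– b: exponent that controls how much the IS weights affect learning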
The weights of high-priority samples are small, so their updates are scaled down (the network will see these experiences many times anyway), whereas low-priority samples get nearly full updates.
The role of the exponent b is to control how much these importance sampling weights affect learning. In practice, b is annealed from its initial value up to 1 over the duration of training, because these weights matter most at the end of learning, when our Q-values begin to converge; the unbiased nature of the updates is most important near convergence.
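The weight formula and the annealing of b can be sketched as follows. Dividing by the largest weight in the batch, so that every weight is at most 1, is a common stabilizing step and an assumption here; the helper names `importance_weights` and `anneal_b` and the increment of 0.001 are hypothetical.

```python
import numpy as np

def importance_weights(probs, buffer_size, b):
    """w_i = (1/N * 1/P(i))**b, divided by the largest weight so that w_i <= 1."""
    probs = np.asarray(probs, dtype=np.float64)
    weights = (buffer_size * probs) ** (-b)
    return weights / weights.max()

def anneal_b(b, increment=0.001):
    """Move b linearly toward 1 (the increment is an assumed value)."""
    return min(1.0, b + increment)

probs = np.array([0.1, 0.2, 0.7])   # sampling probabilities P(i)
w = importance_weights(probs, buffer_size=3, b=1.0)
# The rarely sampled item (P = 0.1) gets the largest weight, the frequent one the smallest.
```

During training, each sample's loss would then be multiplied by its weight, for example `loss = np.mean(w * td_errors ** 2)`. With b = 0 all weights are 1 and no correction is applied.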
This time, the implementation will be a little bit fancier than before.
First of all, we can’t just implement PER by sorting the whole replay buffer according to the priorities. This would not be efficient at all: O(n log n) for insertion and O(n) for sampling.
As explained in this article, instead of sorting an array we need to use another data structure: an unsorted SumTree.
A SumTree is a binary tree, that is, a tree in which each node has at most two children. The leaves (deepest nodes) contain the priority values, each parent node holds the sum of its children’s priorities (so the root holds the total priority), and a data array whose entries correspond to the leaves contains the experiences.
Then, we create a memory object that will contain our SumTree and data.
Next, to sample a minibatch of size k, the range [0, total_priority] will be divided into k ranges. A value is uniformly sampled from each range.
Finally, the transitions (experiences) that correspond to each of these sampled values are retrieved from the SumTree.
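The segment sampling itself does not depend on the tree. The sketch below applies it to a plain array of priorities using cumulative sums; it costs O(n) to rebuild the cumulative sums after every priority change, which is exactly what the SumTree avoids. The function name `segment_sample` is hypothetical.

```python
import numpy as np

def segment_sample(priorities, k, rng):
    """Split [0, total_priority] into k equal segments, draw one value per segment,
    and return the index of the experience whose cumulative-priority interval holds it."""
    cumulative = np.cumsum(np.asarray(priorities, dtype=np.float64))
    total_priority = cumulative[-1]
    segment = total_priority / k
    lows = segment * np.arange(k)
    values = rng.uniform(lows, lows + segment)  # one uniform draw per segment
    indexes = np.searchsorted(cumulative, values, side="right")
    return np.minimum(indexes, len(cumulative) - 1)  # guard against rounding at the top edge

rng = np.random.default_rng(42)
print(segment_sample([1.0, 2.0, 3.0, 4.0], k=4, rng=rng))
```

Here experience i owns the interval [cumulative[i-1], cumulative[i]), so larger priorities cover more of the range and are hit more often.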
I will use Morvan Zhou’s SumTree code from this link. So, first we create a SumTree class.
First, we want to build a tree with all nodes = 0 and initialize the data with all values = 0. So, we define the number of leaf nodes (final nodes) that contain experiences, capacity. Next, with the self.tree = np.zeros(2 * capacity - 1) line we generate the tree with all node values = 0. To understand this calculation (2 * capacity - 1), look at the schema below:
        0
       / \
      0   0
     / \ / \
    0  0 0  0
Since this is a binary tree (each node has at most 2 children), a tree with capacity leaves has capacity - 1 parent nodes, so the total number of nodes is 2 * capacity - 1: Parent nodes = capacity - 1 and Leaf nodes = capacity. Finally, we define our data array that contains the experiences (so the size of data is capacity).
Second, we define an add function that will add our priority score to a SumTree leaf and add the experience to data.
When putting new data into our tree we fill the leaves from left to right, so the first thing we do is look at which index we want to put the experience at:
tree_index = self.data_pointer + self.capacity - 1
This is how our tree will look when we start filling it:
            0
           / \
          0   0
         / \ / \
tree_index  0 0  0
So, when adding new information to our tree we do 3 steps:
- We update the data array: self.data[self.data_pointer] = data
- We update the leaf: self.update(tree_index, priority) (this function will be created later)
- And we shift our pointer one position to the right: self.data_pointer += 1
If we reach the capacity limit, we go back to the first index and overwrite the oldest data.
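To make the index arithmetic concrete, the sketch below only simulates the pointer logic of add for a small capacity; it does not touch priorities. The helper `tree_indexes_written` is hypothetical.

```python
def tree_indexes_written(capacity, n_adds):
    """(data_pointer, tree_index) used by each successive add call: leaves are filled
    left to right, and data_pointer wraps to 0 once capacity is reached."""
    data_pointer = 0
    written = []
    for _ in range(n_adds):
        tree_index = data_pointer + capacity - 1  # leaves start after capacity - 1 parents
        written.append((data_pointer, tree_index))
        data_pointer += 1
        if data_pointer >= capacity:              # buffer full: overwrite the oldest slot
            data_pointer = 0
    return written

print(tree_indexes_written(capacity=4, n_adds=6))
# [(0, 3), (1, 4), (2, 5), (3, 6), (0, 3), (1, 4)]
```

With capacity = 4 the tree array has 2 * 4 - 1 = 7 nodes, and the leaves occupy indexes 3 to 6; the fifth add goes back to index 3.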
As I said, next we create a function to update the leaf priority score and propagate the change through the tree.
In the update function, we first calculate the priority change by subtracting the previous priority score from the new priority, and we overwrite the previous priority with the new one. After that, we propagate the change through the tree in a while loop.
Here is how the node indexes look in a tree with 4 leaves (7 nodes):
       0
      / \
     1   2
    / \ / \
   3  4 5  6
The numbers in this tree are the indexes, not the priority values. Here we want to access the level above the leaves. For example, if we are in the leaf at index 6 and we updated its priority score, we then need to update node 2:
tree_index = (tree_index - 1) // 2
tree_index = (6 - 1) // 2
tree_index = 2  # // is floor division, so 5 // 2 rounds down to 2
The last step is to add the calculated change to this parent node (the loop repeats this until it reaches the root):
self.tree[tree_index] += change
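The update and propagation steps can be sketched as a standalone function on a NumPy array. This is an illustration, not Morvan Zhou’s code; the function works on a bare array instead of a class attribute.

```python
import numpy as np

def update(tree, tree_index, priority):
    """Set the priority of one leaf and add the same change to every ancestor,
    so each parent keeps holding the sum of its two children."""
    change = priority - tree[tree_index]
    tree[tree_index] = priority
    while tree_index != 0:                  # stop after the root has been updated
        tree_index = (tree_index - 1) // 2  # move to the parent node
        tree[tree_index] += change

capacity = 4
tree = np.zeros(2 * capacity - 1)
for data_index, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    update(tree, data_index + capacity - 1, p)
print(tree)  # the root (index 0) now holds the total priority, 10
```

Because only the path from the leaf to the root is touched, an update costs O(log n) instead of re-summing the whole buffer.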
Next, we must build a function to get a leaf from our tree. It returns the leaf_index, the priority value of that leaf, and the experience associated with that leaf index.
To understand what we are doing, let’s look at our tree from the index perspective:
        0           -> storing priority sum
       / \
      1   2
     / \ / \
    3  4 5  6       -> storing priority for experiences
Here we loop in a while loop. First, we find the left and right child indexes of the current node. If the value we are searching for is smaller than or equal to the left child’s priority, we go left; otherwise we subtract the left child’s priority from the value and go right. We repeat this until we reach a leaf. Once we know the leaf index, we calculate the data index, and finally we return the leaf index, the leaf priority, and the data stored at the corresponding index.
At the end I also wrote a `total_priority` function, which returns the root node (the total priority).
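The descent described above can be sketched as standalone functions. `build_tree` is a hypothetical helper that fills a sum tree bottom-up so the example runs on its own; `get_leaf` follows the left/right rule from the text.

```python
import numpy as np

def build_tree(priorities):
    """Hypothetical helper: build a sum tree bottom-up from leaf priorities."""
    capacity = len(priorities)
    tree = np.zeros(2 * capacity - 1)
    tree[capacity - 1:] = priorities
    for i in range(capacity - 2, -1, -1):  # parents, from the last one up to the root
        tree[i] = tree[2 * i + 1] + tree[2 * i + 2]
    return tree

def get_leaf(tree, data, v):
    """Walk down from the root to the leaf whose priority interval contains v."""
    capacity = len(data)
    parent_index = 0
    while True:
        left = 2 * parent_index + 1
        right = left + 1
        if left >= len(tree):  # no children: parent_index is a leaf
            leaf_index = parent_index
            break
        if v <= tree[left]:
            parent_index = left
        else:
            v -= tree[left]
            parent_index = right
    data_index = leaf_index - capacity + 1
    return leaf_index, tree[leaf_index], data[data_index]

def total_priority(tree):
    return tree[0]  # the root holds the sum of all priorities

tree = build_tree([1.0, 2.0, 3.0, 4.0])
data = ["a", "b", "c", "d"]
print(get_leaf(tree, data, 2.5))  # 2.5 falls in experience "b"'s interval [1, 3)
```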
Now that we have finished constructing our SumTree object, we’ll build a Memory object. While writing this tutorial I relied on code from this link. So, same as before, we’ll create a Memory object.
Here we defined 3 hyperparameters:
- PER_e, a hyperparameter we use to avoid some experiences having 0 probability of being taken
- PER_a, a hyperparameter we use to make a tradeoff between taking only experiences with high priority and sampling randomly
- PER_b, the importance-sampling exponent, increasing from its initial value to 1
Above, we created the tree attribute, a SumTree that contains the priority scores at its leaves and the experiences in a data array. Now, unlike in our previous tutorials, we won’t use deque(): once a deque is full, every new experience shifts the index of all the others by one, which would break the link between leaves and data. We prefer to use a simple array and overwrite it when our memory is full.
Next, we define a function to store a new experience in our tree. Each new experience gets a score of max_priority (it will then be updated when we use this experience to train our agent). An experience in the CartPole game, for example, would be (state, action, reward, next_state, done). So, we define our store function.
Here we search for the maximum priority among the leaf nodes that contain experiences. If we can’t find any priority in our tree (all leaves are 0), we set the max priority to absolute_error_upper, in our case 1, and store the experience with it. Otherwise, we store our experience with the maximum priority we found.
Next, we create a sample function, which picks a batch from our tree memory to train our model. First, to sample a minibatch of size n, we divide the range [0, priority_total] into n ranges. Then a value is uniformly sampled from each range. Finally, we search the SumTree for the experience whose priority score corresponds to each sampled value.
And finally, we create a function to update the priorities in the tree.
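Putting the pieces together, the store, sample, and priority-update steps could be sketched as below. This is a condensed illustration, not the code from the linked repositories. Assumptions: the hyperparameter values (PER_e = 0.01, PER_a = 0.6, PER_b = 0.4, absolute_error_upper = 1), the hypothetical PER_b_increment = 0.001, and a plain Python list for data. Stored priorities are already clipped and raised to the power PER_a, so P(i) is just a leaf’s priority divided by the total priority. sample() assumes the memory already holds experiences and does not guard against empty, zero-priority leaves.

```python
import numpy as np

class SumTree:
    """Binary sum tree: leaves hold priorities, each parent holds the sum of its children."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # capacity - 1 parents, then capacity leaves
        self.data = [None] * capacity           # experience stored for each leaf
        self.data_pointer = 0

    def add(self, priority, data):
        tree_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_index, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity  # wrap: overwrite oldest

    def update(self, tree_index, priority):
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        while tree_index != 0:                  # propagate the change up to the root
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def get_leaf(self, v):
        parent = 0
        while 2 * parent + 1 < len(self.tree):  # descend until we reach a leaf
            left = 2 * parent + 1
            if v <= self.tree[left]:
                parent = left
            else:
                v -= self.tree[left]
                parent = left + 1
        return parent, self.tree[parent], self.data[parent - self.capacity + 1]

    @property
    def total_priority(self):
        return self.tree[0]                     # the root holds the sum of all priorities


class Memory:
    PER_e = 0.01                # assumed: keeps every priority above zero
    PER_a = 0.6                 # assumed: 0 = uniform, 1 = fully proportional
    PER_b = 0.4                 # assumed: initial IS exponent, annealed to 1
    PER_b_increment = 0.001     # hypothetical name and value
    absolute_error_upper = 1.0  # clipped absolute error

    def __init__(self, capacity, rng=None):
        self.tree = SumTree(capacity)
        self.rng = rng if rng is not None else np.random.default_rng()

    def store(self, experience):
        max_priority = np.max(self.tree.tree[-self.tree.capacity:])
        if max_priority == 0:   # empty tree: fall back to the upper bound
            max_priority = self.absolute_error_upper
        self.tree.add(max_priority, experience)

    def sample(self, n):
        segment = self.tree.total_priority / n
        self.PER_b = min(1.0, self.PER_b + self.PER_b_increment)
        indexes = np.empty(n, dtype=np.int64)
        priorities = np.empty(n)
        batch = []
        for i in range(n):
            v = self.rng.uniform(segment * i, segment * (i + 1))  # one draw per range
            indexes[i], priorities[i], experience = self.tree.get_leaf(v)
            batch.append(experience)
        probs = priorities / self.tree.total_priority
        weights = (self.tree.capacity * probs) ** (-self.PER_b)  # N cancels after max-normalization
        return indexes, batch, weights / weights.max()

    def batch_update(self, tree_indexes, abs_errors):
        abs_errors = np.asarray(abs_errors, dtype=np.float64) + self.PER_e
        clipped = np.minimum(abs_errors, self.absolute_error_upper)
        for tree_index, priority in zip(tree_indexes, clipped ** self.PER_a):
            self.tree.update(tree_index, priority)
```

In an agent, the indexes and weights returned by sample() would be kept so that, after computing new TD errors, the agent can call batch_update(indexes, errors) and scale each sample’s loss by its weight.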
Now we have finished our Memory and SumTree classes. Don’t worry, everything is uploaded to GitHub, so you can download this code! The code above is in the PER.py script.
Now we can continue with our main agent code. We’ll modify it so that we can switch between deque() and prioritized experience replay with a simple Boolean flag, which will help us compare the results.
Agent with Prioritized Experience Replay
So, now that we know how our Prioritized Experience Replay memory works, I stored the SumTree and Memory classes we created in the PER.py script. We’ll import them with the following new line: from PER import *.
In our DQN Agent initialization we create an object self.MEMORY = Memory(memory_size) with memory_size = 10000. So, when using PER memory, instead of self.memory = deque(maxlen=2000) we’ll use self.MEMORY. And to easily control whether we use PER or not, we’ll add a self.USE_PER = True Boolean flag.
Note: all of these memory_size = 10000 experiences are kept in RAM, so if you set this number too large, you may run out of memory. While implementing PER, almost all the functions stay the same as before; only a few of them change a little. For example, the remember function now looks like this:
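def remember(self, state, action, reward, next_state, done):
    experience = state, action, reward, next_state, done
    store = self.MEMORY.store if self.USE_PER else self.memory.append  # PER or plain deque
    store(experience)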
As I already said, here the Boolean flag chooses which memory type our agent uses.
The replay function changes more (the way we sample our minibatches): we take them either from the PER memory or from the deque. If we take our minibatches from PER, we must recalculate the absolute_errors and update our memory with them.
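The absolute errors only need to be computed for the action that was actually taken in each sample. A minimal sketch, assuming the Q-values predicted before building the target and the target Q-values are both NumPy arrays of shape (batch_size, n_actions); the function name `per_absolute_errors` and the arrays `q_old` and `q_target` are hypothetical.

```python
import numpy as np

def per_absolute_errors(q_old, q_target, actions):
    """|TD error| per sample: compare the old prediction and the new target
    only for the action that was actually taken."""
    rows = np.arange(len(actions))
    actions = np.asarray(actions)
    return np.abs(q_target[rows, actions] - q_old[rows, actions])

q_old = np.array([[1.0, 2.0], [0.5, 0.5]])     # Q(s, .) before writing the TD target
q_target = np.array([[1.0, 3.5], [0.0, 0.5]])  # same rows with the TD target written in
print(per_absolute_errors(q_old, q_target, actions=[1, 0]))  # [1.5 0.5]
```

These errors would then be passed, together with the tree indexes returned by the sampling step, to the memory’s priority-update function.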
Our run function doesn’t change. From this point on, you should download the code from the GitHub link.
Now, let’s look at two examples on the same CartPole balancing game, where I trained our agent for 1000 steps. I trained two agents:
- One with PER enabled
- Another with PER disabled
Both agents were trained with a double dueling Deep Q-Network, epsilon-greedy updates, and soft update disabled.
First, let’s look at the results of training our agent without PER: they look very similar to those in my previous tutorial. The best average score our agent could reach was around 1060: