How To Build Your Own DeepMind MuZero In Python (Part 3/3)

Source: Deep Learning on Medium

Now that we have seen how the targets are built, we can see how they fit into the MuZero loss function and, finally, how they are used in the update_weights function to train the networks.

The MuZero loss function

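The loss function for MuZero is as follows:

$$l_t(\theta) = \sum_{k=0}^{K} \left[ l^{r}\left(u_{t+k}, r_t^{k}\right) + l^{v}\left(z_{t+k}, v_t^{k}\right) + l^{p}\left(\pi_{t+k}, p_t^{k}\right) \right] + c\,\lVert\theta\rVert^2$$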

Here, K is the num_unroll_steps variable, and there are three losses we are trying to minimise:

  1. The difference between the predicted reward k steps ahead of turn t (r) and the actual reward (u)
  2. The difference between the predicted value k steps ahead of turn t (v) and the TD target value (z)
  3. The difference between the predicted policy k steps ahead of turn t (p) and the MCTS policy (pi)

These losses are summed over the rollout to generate the loss for a given position in the batch. There is also a regularisation term to penalise large weights in the network.
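As a rough illustration of the loss described above, here is a minimal pure-Python sketch for a single position. It assumes mean squared error for both the reward and value terms (the choice the MuZero pseudocode suggests for board games; Atari uses a categorical loss instead), treats the predicted policies as probability vectors, and takes a hypothetical flat list of network parameters for the regularisation term. The names `position_loss`, `squared_error` and `cross_entropy` are illustrative, not taken from the original code.

```python
import math
from typing import Sequence

def squared_error(prediction: float, target: float) -> float:
    # Assumed scalar loss: squared error for a single value
    return (prediction - target) ** 2

def cross_entropy(target_policy: Sequence[float], predicted_policy: Sequence[float]) -> float:
    # -sum(pi * log p); predicted_policy is assumed to be a probability vector
    eps = 1e-12  # avoids log(0) when a predicted probability is exactly zero
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target_policy, predicted_policy))

def position_loss(pred_rewards, pred_values, pred_policies,
                  target_rewards, target_values, target_policies,
                  parameters: Sequence[float], c: float = 1e-4) -> float:
    """Sum the three losses over k = 0..K, then add c * ||theta||^2."""
    loss = 0.0
    for k in range(len(pred_values)):  # K + 1 entries: the initial step plus K unrolled steps
        loss += squared_error(pred_rewards[k], target_rewards[k])    # reward: r vs u
        loss += squared_error(pred_values[k], target_values[k])      # value: v vs z
        loss += cross_entropy(target_policies[k], pred_policies[k])  # policy: p vs pi
    loss += c * sum(w * w for w in parameters)  # L2 regularisation on the weights
    return loss

# Example with K = 1 (one unrolled step)
loss = position_loss(
    pred_rewards=[0.0, 0.5], pred_values=[0.2, 0.4],
    pred_policies=[[0.5, 0.5], [1.0, 0.0]],
    target_rewards=[0.0, 1.0], target_values=[0.0, 1.0],
    target_policies=[[1.0, 0.0], [1.0, 0.0]],
    parameters=[1.0, 2.0], c=0.01,
)
print(round(loss, 4))  # 0.25 + 0.40 + ln 2 + 0.05, roughly 1.3931
```

Each of the three terms is summed over every step of the rollout, and the weight penalty is added once per position.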

Updating the three MuZero networks (update_weights)

The update_weights function builds the loss piece by piece, for each of the 2048 positions in the batch.

First, the initial observation is passed through the initial_inference network to predict the value, reward and policy from the current position. These form the first entry in the predictions list, with a weighting of 1.0.

Then each action is looped over in turn, and the recurrent_inference function is asked to predict the next value, reward and policy from the current hidden_state. These are appended to the predictions list with a weighting of 1/num_unroll_steps, so that the recurrent_inference steps together carry the same total weight as the single initial_inference step.
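The loop described above can be sketched in plain Python. Here `initial_inference` and `recurrent_inference` are passed in as functions so the example runs without a neural network library; the toy stand-ins below are hypothetical and exist only to show how the hidden state flows from step to step and how the weightings add up.

```python
from typing import Callable, List, Sequence, Tuple

# (weighting, value, reward, policy_logits)
Prediction = Tuple[float, float, float, List[float]]

def build_predictions(observation, actions: Sequence[int],
                      initial_inference: Callable,
                      recurrent_inference: Callable) -> List[Prediction]:
    # Initial step from the real observation, weighted 1.0
    value, reward, policy_logits, hidden_state = initial_inference(observation)
    predictions: List[Prediction] = [(1.0, value, reward, policy_logits)]

    # Recurrent steps: each one starts from the previous hidden state, never from the board
    for action in actions:
        value, reward, policy_logits, hidden_state = recurrent_inference(hidden_state, action)
        predictions.append((1.0 / len(actions), value, reward, policy_logits))
    return predictions

# Hypothetical toy stand-ins for the networks (not real MuZero models)
def toy_initial_inference(observation):
    hidden = float(observation)    # pretend representation: identity
    return 0.5 * hidden, 0.0, [0.0, 0.0], hidden

def toy_recurrent_inference(hidden, action):
    next_hidden = hidden + action  # pretend dynamics: add the action
    return 0.5 * next_hidden, 1.0, [0.0, 0.0], next_hidden

preds = build_predictions(2, [1, 1, 1, 1, 1], toy_initial_inference, toy_recurrent_inference)
weights = [p[0] for p in preds]
# Initial step weight vs. total weight of the five recurrent steps (both 1.0, up to rounding)
print(weights[0], sum(weights[1:]))
```

With five unrolled steps, each recurrent prediction gets a weighting of 0.2, so the recurrent part as a whole counts as much as the initial step.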

We then calculate the loss that compares the predictions to their corresponding target values. This combines scalar_loss for the reward and value with softmax_cross_entropy_with_logits for the policy, and each step's contribution is scaled by its weighting before being added to the total.
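For the policy term, here is a small pure-Python softmax cross-entropy on logits (mirroring what TensorFlow's `tf.nn.softmax_cross_entropy_with_logits` computes), together with a per-step loss that follows the combination described above. It assumes mean squared error for `scalar_loss`, as for board games; `step_loss` is a hypothetical helper, and multiplying by the weighting is a simplified stand-in for how the original code scales each step.

```python
import math
from typing import Sequence

def softmax_cross_entropy_with_logits(logits: Sequence[float], labels: Sequence[float]) -> float:
    """-sum(labels * log_softmax(logits)), computed with the log-sum-exp trick."""
    m = max(logits)  # subtract the max logit so exp() never overflows
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(y * (x - log_sum_exp) for x, y in zip(logits, labels))

def scalar_loss(prediction: float, target: float) -> float:
    # Assumed: squared error, as suggested for board games
    return (prediction - target) ** 2

def step_loss(prediction, target) -> float:
    gradient_scale, value, reward, policy_logits = prediction
    target_value, target_reward, target_policy = target
    l = (scalar_loss(value, target_value)
         + scalar_loss(reward, target_reward)
         + softmax_cross_entropy_with_logits(policy_logits, target_policy))
    # Weighting this step's loss also scales its gradient by the same factor
    return gradient_scale * l

print(softmax_cross_entropy_with_logits([0.0, 0.0], [1.0, 0.0]))     # ln 2 for a uniform guess
print(softmax_cross_entropy_with_logits([1000.0, 0.0], [1.0, 0.0]))  # no overflow thanks to the max shift
```

Working on logits rather than probabilities lets the softmax and the logarithm be fused, which is both cheaper and numerically safer than computing the probabilities first.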

The optimiser then minimises this loss, plus an L2 weight-decay term for the regularisation, to train all three of the MuZero networks simultaneously.

So…that’s how you train MuZero using Python.


In summary, AlphaZero is hard-wired to know three things:

  • What happens to the board when it makes a given move. For example, if it takes the action ‘move pawn from e2 to e4’, it knows that the next board position is the same, except that the pawn will have moved.
  • What the legal moves are in a given position. For example, AlphaZero knows that you can’t move ‘queen to c3’ if your queen is off the board, a piece is blocking the move, or you already have a piece on c3.
  • When the game is over and who won. For example, it knows that if the opponent’s king is in check and cannot move out of check, it has won.

In other words, AlphaZero can imagine possible futures because it knows the rules of the game.

MuZero is denied access to these fundamental game mechanics throughout the training process. What is remarkable is that by adding a couple of extra neural networks, it is able to cope with not knowing the rules.

In fact, it flourishes.
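To make the contrast concrete, here is a rough Python sketch of the two interfaces. The `Rules` fields mirror the three things AlphaZero is given, while `LearnedModel` holds MuZero's three learned functions: representation (h), dynamics (g) and prediction (f). All names and the toy model are hypothetical; the point is only that `imagine` plans forward without ever consulting the rules.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple

@dataclass
class Rules:
    """What AlphaZero is given (hypothetical interface)."""
    apply: Callable[[Any, int], Any]           # the next board after a move
    legal_actions: Callable[[Any], List[int]]  # which moves are allowed
    is_terminal: Callable[[Any], bool]         # whether the game is over

@dataclass
class LearnedModel:
    """What MuZero learns instead (hypothetical interface)."""
    representation: Callable[[Any], List[float]]                       # h: observation -> hidden state
    dynamics: Callable[[List[float], int], Tuple[List[float], float]]  # g: (hidden, action) -> (next hidden, reward)
    prediction: Callable[[List[float]], Tuple[List[float], float]]     # f: hidden -> (policy, value)

def imagine(model: LearnedModel, observation: Any, actions: List[int]) -> Tuple[List[float], float]:
    """Roll a plan forward purely in hidden-state space; no Rules object is needed."""
    hidden = model.representation(observation)
    rewards = []
    for action in actions:
        hidden, reward = model.dynamics(hidden, action)
        rewards.append(reward)
    _policy, value = model.prediction(hidden)
    return rewards, value

# Toy model with made-up arithmetic, purely to exercise the interface
toy = LearnedModel(
    representation=lambda obs: [float(obs)],
    dynamics=lambda h, a: ([h[0] + a], 0.1 * a),
    prediction=lambda h: ([0.5, 0.5], h[0] / 10),
)
rewards, value = imagine(toy, 3, [1, 2])
print(rewards, value)
```

In AlphaZero the "next position" comes from the real rules; in this sketch it comes from whatever the dynamics function has learned, which is exactly the gap MuZero's extra networks fill.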

Incredibly, MuZero actually improves on AlphaZero’s performance in Go. This may indicate that it is finding more efficient ways to represent a position through its hidden representation than AlphaZero can find when using the actual board positions. The mysterious ways in which MuZero is embedding the game in its own mind will surely be an active area of research for DeepMind in the near future.