Source: Deep Learning on Medium
Last time, we introduced MuZero and saw how it is different from its older brother, AlphaZero.
In the absence of the actual rules of chess, MuZero creates a new game inside its mind that it can control and uses this to plan into the future. The three networks (prediction, dynamics and representation) are optimised together so that strategies that perform well inside the imagined environment also perform well in the real environment.
In this post, we’ll walk through the play_game function and see how MuZero decides on the best next move at each turn. We’ll also explore the MCTS process in more detail.
Playing a game with MuZero (play_game)
We’re now going to step through the play_game function.

First, a new Game object is created and the main game loop is started. The game ends when a terminal condition has been met or when the number of moves exceeds the maximum allowed.
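As a rough sketch, the outer loop has the following shape. Game here is a toy stub that never hits a terminal state, and MAX_MOVES stands in for the real config.max_moves; both are assumptions for illustration, not the actual implementation.

```python
class Game:
    # Toy stand-in for the real Game object: records applied actions and
    # never reaches a terminal state, so the move cap ends the loop.
    def __init__(self):
        self.history = []

    def terminal(self):
        return False

    def apply(self, action):
        self.history.append(action)


MAX_MOVES = 5  # stands in for config.max_moves

game = Game()
while not game.terminal() and len(game.history) < MAX_MOVES:
    # In the real loop: build a root Node, run MCTS, then pick an action
    # from the root visit counts. Here we just apply a dummy action.
    action = 0
    game.apply(action)
```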
We start the MCTS tree with the root node.
root = Node(0)
Each Node stores key statistics: the number of times it has been visited (visit_count), whose turn it is (to_play), the predicted prior probability of choosing the action that leads to this node (prior), the backfilled value sum of the node (value_sum), its child nodes (children), the hidden state it corresponds to (hidden_state) and the predicted reward received by moving to this node (reward).
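A minimal sketch of the Node statistics described above, following the structure of the open-source MuZero pseudocode; the value() helper, which averages the backfilled value sum over the visit count, is how these statistics are typically read back during search.

```python
class Node:
    def __init__(self, prior: float):
        self.visit_count = 0      # times this node has been visited
        self.to_play = -1         # player to move at this node
        self.prior = prior        # prior probability from the prediction network
        self.value_sum = 0.0      # backfilled sum of values through this node
        self.children = {}        # action -> child Node
        self.hidden_state = None  # hidden state produced by the networks
        self.reward = 0.0         # predicted reward for reaching this node

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        # Mean value of the node; zero if it has never been visited.
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count
```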
Next we ask the game to return the current observation (corresponding to o in the diagram above)…
current_observation = game.make_image(-1)
…and expand the root node using the known legal actions provided by the game and the inference about the current observation provided by the network’s initial_inference function.
expand_node(root, game.to_play(), game.legal_actions(),network.initial_inference(current_observation))
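A minimal sketch of what expand_node does, based on the open-source MuZero pseudocode: it stores the inference results on the node and creates one child per legal action, with priors taken from the softmax of the policy logits. The Node and network-output stubs here carry only the fields needed for the sketch.

```python
import math


class Node:
    # Minimal stand-in for the Node class with just the fields used here.
    def __init__(self, prior: float):
        self.prior = prior
        self.to_play = -1
        self.hidden_state = None
        self.reward = 0.0
        self.children = {}


def expand_node(node, to_play, legal_actions, network_output):
    # Record the inference results on the node, then create one child per
    # legal action with a prior from the softmax over the policy logits.
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in legal_actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)
```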
We also need to add exploration noise to the root node actions. This is important to ensure that the MCTS explores a range of possible actions rather than only the action it currently believes to be optimal. For chess, the noise is drawn from a Dirichlet distribution with α = 0.3 and mixed into the root priors with an exploration fraction of 0.25.
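A stdlib-only sketch of the root-noise step. It operates on a plain dict of action priors rather than Node objects (a simplification), and builds the Dirichlet sample from normalised gamma draws to avoid a numpy dependency; the α = 0.3 and fraction = 0.25 defaults follow the board-game settings reported for chess.

```python
import random


def add_exploration_noise(priors, dirichlet_alpha=0.3, exploration_fraction=0.25):
    # priors: dict mapping action -> prior probability at the root.
    # A Dirichlet(alpha) sample is a vector of gamma draws normalised to
    # sum to one; it is then mixed into the priors with the given fraction.
    actions = list(priors)
    gammas = [random.gammavariate(dirichlet_alpha, 1.0) for _ in actions]
    total = sum(gammas)
    noise = [g / total for g in gammas]
    frac = exploration_fraction
    return {a: priors[a] * (1 - frac) + n * frac
            for a, n in zip(actions, noise)}
```

Because the mix is a convex combination of two distributions, the noisy priors still sum to one.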
We now hit the main MCTS process, which we will cover in the next section.
run_mcts(config, root, game.action_history(), network)
The Monte Carlo Search Tree in MuZero (run_mcts)
As MuZero has no knowledge of the environment rules, it also has no knowledge of the bounds on the rewards that it may receive throughout the learning process. The MinMaxStats object is created to store the current minimum and maximum rewards encountered so that MuZero can normalise its value output accordingly. Alternatively, it can be initialised with known bounds for a game such as chess (-1, 1).
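A sketch of this normalisation helper, following the open-source MuZero pseudocode: it widens its bounds as new values arrive and rescales any value into [0, 1] once a non-trivial range has been seen.

```python
class MinMaxStats:
    def __init__(self, known_bounds=None):
        # With no known bounds, start with an empty (inverted) range that
        # the first update will overwrite.
        self.minimum = known_bounds[0] if known_bounds else float('inf')
        self.maximum = known_bounds[1] if known_bounds else float('-inf')

    def update(self, value):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        # Only normalise once a real range has been observed.
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
```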
The main MCTS loop iterates over num_simulations, where one simulation is a pass through the MCTS tree until a leaf node (i.e. an unexplored node) is reached, followed by backpropagation of the result. Let’s walk through one simulation now.
The history is initialised with the list of actions taken so far from the start of the game. The current node is the root node and the search_path contains only the current node.
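A self-contained sketch of that setup and the descent to a leaf. The Node stub carries only the fields used here, and the child choice is simplified to "take the first child"; in the real search the child is selected by the UCB formula.

```python
class Node:
    # Stub with just the fields needed for the descent (an assumption).
    def __init__(self, prior):
        self.prior = prior
        self.children = {}

    def expanded(self):
        return len(self.children) > 0


root = Node(0)
root.children = {0: Node(0.5), 1: Node(0.5)}  # pretend the root was expanded

history = [4, 12]       # actions taken so far in the real game
node = root
search_path = [node]    # the path starts with the root alone

# Descend until an unexpanded (leaf) node is reached, recording each
# action in the history and each node in the search path.
while node.expanded():
    action, node = next(iter(node.children.items()))
    history.append(action)
    search_path.append(node)
```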
The simulation then proceeds as shown in the diagrams below: