How To Build Your Own DeepMind MuZero In Python (Part 2/3)

Last time, we introduced MuZero and saw how it is different from its older brother, AlphaZero.

In the absence of the actual rules of chess, MuZero creates a new game inside its mind that it can control and uses this to plan into the future. The three networks (prediction, dynamics and representation) are optimised together so that strategies that perform well inside the imagined environment also perform well in the real environment.

In this post, we’ll walk through the play_game function and see how MuZero makes a decision about the next best move at each turn. We’ll also explore the MCTS process in more detail.

Playing a game with MuZero (play_game)

We’re now going to step through the play_game function:
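For reference, here is the function in full. This sketch closely follows the pseudocode DeepMind published alongside the MuZero paper; select_action, which chooses the move to play from the root's visit counts, comes from the same pseudocode.

def play_game(config: MuZeroConfig, network: Network) -> Game:
  game = config.new_game()

  while not game.terminal() and len(game.history) < config.max_moves:
    # At the root of the search tree, use the representation function
    # (via initial_inference) to obtain a hidden state from the current observation.
    root = Node(0)
    current_observation = game.make_image(-1)
    expand_node(root, game.to_play(), game.legal_actions(),
                network.initial_inference(current_observation))
    add_exploration_noise(config, root)

    # Run a Monte Carlo Tree Search using only action sequences and the
    # model learned by the networks.
    run_mcts(config, root, game.action_history(), network)
    action = select_action(config, len(game.history), root, network)
    game.apply(action)
    game.store_search_statistics(root)
  return game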

First, a new Game object is created and the main game loop is started. The game ends when a terminal condition has been met or when the number of moves exceeds the maximum allowed.

We start the MCTS tree with the root node.

root = Node(0)

Each Node stores key statistics:

- visit_count: the number of times the node has been visited
- to_play: whose turn it is at this node
- prior: the predicted prior probability of choosing the action that leads to this node
- value_sum: the backfilled sum of the values of the node
- children: its child nodes
- hidden_state: the hidden state the node corresponds to
- reward: the predicted reward received by moving to this node
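In code, the Node class is only a few lines (a sketch along the lines of the published pseudocode):

class Node(object):

  def __init__(self, prior: float):
    self.visit_count = 0      # number of times this node has been visited
    self.to_play = -1         # whose turn it is at this node
    self.prior = prior        # prior probability of the action leading here
    self.value_sum = 0        # backfilled sum of values from simulations
    self.children = {}        # child nodes, keyed by action
    self.hidden_state = None  # hidden state produced by the model
    self.reward = 0           # predicted reward for moving to this node

  def expanded(self) -> bool:
    return len(self.children) > 0

  def value(self) -> float:
    # Mean value of the node across all visits.
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count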

Next we ask the game to return the current observation, where the index -1 selects the most recent position (this corresponds to o in the diagram above)…

current_observation = game.make_image(-1)

…and expand the root node, using the legal actions supplied by the game and the output of the initial_inference function for the current observation.

expand_node(root, game.to_play(), game.legal_actions(), network.initial_inference(current_observation))
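expand_node itself is short. Following the published pseudocode (and assuming its Player, Action and NetworkOutput types), it records whose turn it is, stores the hidden state and predicted reward from the network output, and creates a child for each legal action with a prior given by a softmax over the policy logits:

import math
from typing import List

def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
  node.to_play = to_play
  node.hidden_state = network_output.hidden_state
  node.reward = network_output.reward
  # Softmax over the policy logits, restricted to the legal actions.
  policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
  policy_sum = sum(policy.values())
  for action, p in policy.items():
    node.children[action] = Node(p / policy_sum)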

We also need to add exploration noise to the root node actions — this is important to ensure that the MCTS explores a range of possible actions rather than only exploring the action which it currently believes to be optimal. For chess, root_dirichlet_alpha = 0.3.

add_exploration_noise(config, root)
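The noise is Dirichlet noise, mixed into the priors of the root's children; in the published pseudocode the mixing weight root_exploration_fraction is 0.25:

import numpy

def add_exploration_noise(config: MuZeroConfig, node: Node):
  actions = list(node.children.keys())
  noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    # Blend the network prior with the sampled noise.
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac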

We now hit the main MCTS process, which we will cover in the next section.

run_mcts(config, root, game.action_history(), network)

The Monte Carlo Search Tree in MuZero (run_mcts)

As MuZero has no knowledge of the environment rules, it also has no knowledge of the bounds on the rewards that it may receive throughout the learning process. The MinMaxStats object is created to store the minimum and maximum values encountered so far, so that MuZero can normalise its value output accordingly. Alternatively, it can be initialised with the known bounds of a game such as chess (-1, 1).
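A minimal sketch of MinMaxStats, again following the published pseudocode:

MAXIMUM_FLOAT_VALUE = float('inf')

class MinMaxStats(object):
  """Holds the minimum and maximum values observed in the tree."""

  def __init__(self, known_bounds=None):
    # If bounds are known in advance (e.g. (-1, 1) for chess), start from them;
    # otherwise start empty and update as values are observed.
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # Only normalise once both bounds have actually been observed.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value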

The main MCTS loop iterates num_simulations times; one simulation is a single pass down the MCTS tree until a leaf node (i.e. an as-yet unexplored node) is reached, followed by backpropagation of the evaluation back up the search path. Let's walk through one simulation now.

First, the history is initialised with the list of actions taken so far from the start of the game. The current node is the root node and the search_path initially contains only the current node.
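In code (action_history here is the ActionHistory object passed into run_mcts, recording the actions played in the real game):

history = action_history.clone()
node = root
search_path = [node]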

The simulation then proceeds as shown in the diagrams below: