Implementing Fast AlphaZero

While working at Cogito NTNU, I took on the challenge of implementing AlphaZero from scratch. What I thought would be a straightforward paper implementation turned into a multi-month journey of debugging and optimization. Along the way, I learned firsthand why research papers tend to gloss over the nitty-gritty implementation details.

How AlphaZero Actually Works

AlphaZero combines Monte Carlo Tree Search with a neural network that tries to predict two things: which moves are good, and who’s winning. The training loop is pretty straightforward:

  1. Self-play to generate training data
  2. Train the neural network on that data
  3. Test if the new network is better than the old one
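
In code, the outer loop is just those three steps repeated until the network stops improving. Here's a minimal sketch; the self_play, train, and evaluate helpers are illustrative stand-ins, not our actual functions:

def train_alphazero(network, num_iterations: int = 100):
    """Sketch of the AlphaZero outer loop; the helper functions are placeholders."""
    best = network
    for _ in range(num_iterations):
        examples = self_play(best)            # 1. generate (state, pi, z) training examples
        candidate = train(best, examples)     # 2. fit a candidate network on that data
        if evaluate(candidate, best) > 0.55:  # 3. keep it only if it beats the old one
            best = candidate
    return best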

The neural network outputs a policy vector $\pi$ (move probabilities) and a value estimate $v \in [-1, 1]$ for how good the current position is.

Here’s what our network looked like:

import torch
import torch.nn as nn

class PolicyValueNet(nn.Module):
    """
    Neural network that outputs both policy and value predictions.

    Args:
        board_size: Size of the game board (height, width)
        action_size: Number of possible actions
        num_channels: Number of channels in convolutional layers
    """
    def __init__(self, board_size: tuple[int, int], action_size: int, num_channels: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, num_channels, 3, padding=1),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, 3, padding=1),
            nn.BatchNorm2d(num_channels),
            nn.ReLU()
        )

        board_size_flat = board_size[0] * board_size[1]
        self.policy_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, action_size),
            nn.LogSoftmax(dim=1)
        )

        self.value_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, 1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (log-policy, value) for a batch of board tensors."""
        features = self.conv(x)               # (B, C, H, W)
        flat = features.flatten(start_dim=1)  # (B, C * H * W)
        return self.policy_head(flat), self.value_head(flat)
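
As a quick sanity check, here's what a forward pass looks like on a hypothetical 3x3 Tic Tac Toe configuration (the sizes here are illustrative, not our training setup):

net = PolicyValueNet(board_size=(3, 3), action_size=9, num_channels=64)
dummy_board = torch.zeros(1, 3, 3, 3)   # (batch, planes, height, width)
log_policy, value = net(dummy_board)
print(log_policy.shape, value.shape)    # torch.Size([1, 9]) torch.Size([1, 1])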


Figure 1: A visualization of our MCTS implementation showing a Tic Tac Toe board with red and white circles (and one empty black position) on the left, and its corresponding search tree on the right displaying visit counts and move numbers at each node, starting from 500 visits at the root.

The Loss Function

The network learns by minimizing this loss:

$$L = (z - v)^2 - \pi^T \log p + c\|\theta\|^2$$

where $z$ is the actual game outcome, $v$ is the network's value prediction, $\pi$ is the move distribution produced by MCTS, $p$ is the network's policy prediction, and the last term is L2 regularization on the weights $\theta$ to prevent overfitting.
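
As a concrete sketch (not our exact training code), the loss is cheap to compute from the network outputs; since the network emits log-probabilities, the policy term is just a dot product:

import torch

def alphazero_loss(log_p: torch.Tensor, v: torch.Tensor,
                   pi: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Value MSE plus policy cross-entropy; shapes are (B, actions) and (B, 1)."""
    value_loss = torch.mean((z - v) ** 2)
    policy_loss = -torch.mean(torch.sum(pi * log_p, dim=1))
    return value_loss + policy_loss

# The c * ||theta||^2 term is usually handled by the optimizer instead, e.g.
# torch.optim.Adam(net.parameters(), weight_decay=1e-4)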

Making It Actually Fast

The biggest lesson was that the algorithm is only half the battle. Getting good performance required a lot of engineering tricks that nobody talks about in papers.

Parallel Everything

Running MCTS simulations in parallel was huge for performance:

import torch
import torch.multiprocessing as mp

class ParallelMCTS:
    """
    Parallel MCTS implementation using multiple processes.

    Args:
        network: Policy/value network used to evaluate positions
        num_processes: Number of parallel processes to use
        batch_size: Size of evaluation batches
        num_simulations: Number of MCTS simulations per move
    """
    def __init__(self, network: torch.nn.Module, num_processes: int = 8,
                 batch_size: int = 32, num_simulations: int = 800):
        self.network = network
        self.pool = mp.Pool(num_processes)
        self.batch_size = batch_size
        self.num_simulations = num_simulations

    def run_batch_evaluation(self, positions: list[torch.Tensor]) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Evaluates positions with the network, batch_size positions at a time."""
        with torch.no_grad():
            policies, values = [], []
            for i in range(0, len(positions), self.batch_size):
                batch = positions[i:i + self.batch_size]
                batch_policies, batch_values = self.network(torch.stack(batch))
                policies.extend(batch_policies)
                values.extend(batch_values)
            return list(zip(policies, values))
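
The process pool is what spreads self-play games across CPU cores; the dispatch looked roughly like this (play_one_game is a hypothetical worker standing in for our actual self-play function):

def play_one_game(seed: int) -> list[tuple]:
    """Hypothetical worker: plays one self-play game, returns (state, pi, z) examples."""
    return []  # self-play logic elided

if __name__ == "__main__":
    # Each process plays its own games; the results form one big training set.
    with mp.Pool(processes=8) as pool:
        games = pool.map(play_one_game, range(64))   # 64 games in parallel
    training_examples = [ex for game in games for ex in game]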

Smart Caching

We also added caching to avoid evaluating the same positions over and over:

import numpy as np

class KVCache:
    """
    Key value cache for MCTS node evaluations.

    Args:
        max_size: Maximum number of positions to cache
    """
    def __init__(self, max_size: int = 100000):
        self.cache: dict[str, tuple[np.ndarray, float]] = {}
        self.max_size = max_size

    def get_cached_value(self, board_hash: str) -> tuple[np.ndarray, float] | None:
        return self.cache.get(board_hash)

    def store_value(self, board_hash: str, policy: np.ndarray, value: float) -> None:
        """Stores an evaluation, evicting the oldest entry once the cache is full."""
        if len(self.cache) >= self.max_size:
            self.cache.pop(next(iter(self.cache)))  # dicts preserve insertion order
        self.cache[board_hash] = (policy, value)
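
Wired into the evaluation step, the lookup works roughly like this (the byte-string hashing and the surrounding names are illustrative, not our exact API):

import torch

def evaluate_position(board: np.ndarray, network: torch.nn.Module,
                      cache: KVCache) -> tuple[np.ndarray, float]:
    """Returns (policy, value), hitting the network only on a cache miss."""
    board_hash = board.tobytes().hex()   # any stable key for the position works
    cached = cache.get_cached_value(board_hash)
    if cached is not None:
        return cached
    with torch.no_grad():
        x = torch.from_numpy(board).float().unsqueeze(0)   # add batch dimension
        log_policy, value = network(x)
    policy = log_policy.exp().squeeze(0).numpy()
    cache.store_value(board_hash, policy, float(value))
    return policy, float(value)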

What We Actually Achieved

After all the optimizations, we got some pretty solid improvements:

  • Self-play got 3.2x faster
  • Network training sped up by 2.8x
  • Memory usage dropped by 45%

We also tried this thing called path consistency optimization, which forces the value predictions to be consistent along search paths:

$$L_{PC} = \|f_v - \bar{f}_v\|^2$$

Here $f_v$ is the value prediction at a node and $\bar{f}_v$ is the mean of the value predictions along the search path it lies on. It helped the network learn faster, though honestly I'm still not 100% sure why it works so well.
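
A minimal sketch of how we read that term, assuming the value predictions along one search path have been collected into a tensor (a sketch of the idea, not a reference implementation):

def path_consistency_loss(path_values: torch.Tensor) -> torch.Tensor:
    """path_values: predictions f_v at each node along one search path, shape (path_len,)."""
    mean_value = path_values.mean()   # \bar{f}_v: the path's average value estimate
    return torch.mean((path_values - mean_value) ** 2)

Whether to stop gradients through the mean, treating it as a fixed target, is a design choice this sketch leaves open.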

What I Actually Learned

Building this taught me that implementing research papers is way harder than it looks. The papers make everything sound clean and simple, but getting something that actually runs fast requires tons of boring engineering work that never gets mentioned.

The biggest time sink was debugging why our implementation was so much slower than reported results. Turns out a lot of the speedup comes from implementation details that are just assumed knowledge.

We put the code up on GitHub if you want to check it out. Fair warning though: getting it to run well takes some patience.