Implementing Fast AlphaZero

While I was at Cogito NTNU, I decided to implement AlphaZero from scratch with a few teammates. I assumed this would be a few weekends of work, since the paper looks clean and the algorithm fits on a single slide. It ended up being months of debugging and optimization. Research papers tend to skip over the implementation details that turn out to be most of the actual work.

How AlphaZero works

AlphaZero combines Monte Carlo tree search with a neural network that predicts two things: which moves are good (a policy) and who’s winning the position (a value). The training loop is conceptually simple:

  1. Self-play to generate training data.
  2. Train the neural network on that data.
  3. Test whether the new network beats the previous one.
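The three steps form a short outer loop. The sketch below only illustrates that structure and is not our training code. `self_play`, `train`, and `evaluate` are hypothetical callables standing in for the real components. The 55% win-rate threshold for accepting a new network is an assumption borrowed from AlphaGo Zero's evaluator.

```python
from collections.abc import Callable
from typing import TypeVar

Net = TypeVar("Net")

def training_loop(
    net: Net,
    self_play: Callable[[Net], list],      # hypothetical: returns (state, pi, z) examples
    train: Callable[[Net, list], Net],     # hypothetical: returns a newly trained network
    evaluate: Callable[[Net, Net], float], # hypothetical: win rate of candidate vs. current best
    iterations: int,
    win_threshold: float = 0.55,           # assumed acceptance threshold
) -> Net:
    best = net
    for _ in range(iterations):
        examples = self_play(best)          # 1. generate data with the current best network
        candidate = train(best, examples)   # 2. fit a new network on that data
        if evaluate(candidate, best) > win_threshold:  # 3. keep it only if it wins clearly
            best = candidate
    return best
```

A candidate that fails the evaluation is discarded, and the next round of self-play still uses the previous best network.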

The network outputs a policy vector \(p\) over moves and a value \(v \in [-1, 1]\) estimating how good the position is for the player to move. The policy and value heads share a convolutional trunk:

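```python
import torch
import torch.nn as nn

class PolicyValueNet(nn.Module):
    def __init__(self, board_size: tuple[int, int], action_size: int, num_channels: int = 256):
        super().__init__()
        # Shared convolutional trunk; the input has 3 feature planes per board cell
        self.conv = nn.Sequential(
            nn.Conv2d(3, num_channels, 3, padding=1),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, 3, padding=1),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
        )

        board_size_flat = board_size[0] * board_size[1]
        # Policy head: log-probabilities over all actions
        self.policy_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, action_size),
            nn.LogSoftmax(dim=1),
        )
        # Value head: a single scalar squashed into [-1, 1]
        self.value_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, 3, rows, cols); padding=1 keeps the spatial size, so the
        # flattened trunk output has num_channels * rows * cols features
        features = self.conv(x).flatten(start_dim=1)
        return self.policy_head(features), self.value_head(features)
```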

A snapshot of our MCTS implementation on a Tic-Tac-Toe board, with the search tree on the right showing visit counts per node, starting from 500 at the root.
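Visit counts like the ones in the tree come from the selection rule, which uses the network's policy as a prior. The sketch below shows the PUCT rule from the AlphaZero papers: choose the child that maximizes \(Q + c_{puct} \cdot P \cdot \sqrt{N_{parent}} / (1 + N)\). The `Child` dataclass and the default `c_puct = 1.5` are illustrative assumptions, not taken from our implementation.

```python
import math
from dataclasses import dataclass

@dataclass
class Child:
    prior: float            # P(s, a): network policy probability for this move
    visits: int = 0         # N(s, a): how many simulations went through this move
    value_sum: float = 0.0  # W(s, a): sum of values backed up through this move

    @property
    def q(self) -> float:
        # Mean action value Q(s, a); unvisited children default to 0
        return self.value_sum / self.visits if self.visits else 0.0

def select_child(children: dict[int, Child], c_puct: float = 1.5) -> int:
    """Return the move maximizing Q + c_puct * P * sqrt(N_parent) / (1 + N)."""
    parent_visits = sum(c.visits for c in children.values())
    sqrt_parent = math.sqrt(parent_visits)

    def score(move: int) -> float:
        c = children[move]
        return c.q + c_puct * c.prior * sqrt_parent / (1 + c.visits)

    return max(children, key=score)
```

The exploration term shrinks as a move collects visits, so a move with a high prior gets explored early, and its measured Q value takes over as the simulation count grows toward the 500 shown at the root.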

The loss function

The network learns by minimizing a single combined loss:

\[L = (z - v)^2 - \pi^T \log p + c\|\theta\|^2\]

where \(z\) is the actual game outcome, \(v\) is the network’s predicted value, \(\pi\) is the move distribution from MCTS, \(p\) is the network’s predicted policy, and the last term is L2 regularization to keep the weights from running away.
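Here are the three terms for a single position, written in plain Python to keep them concrete. This is only a sketch. A real training step computes them on batched tensors, and the L2 term is often left to the optimizer's weight decay instead of being added by hand. The default `c = 1e-4` is an illustrative choice.

```python
import math

def alphazero_loss(z: float, v: float, pi: list[float], p: list[float],
                   theta: list[float], c: float = 1e-4) -> float:
    """L = (z - v)^2 - pi^T log p + c * ||theta||^2 for one position."""
    value_loss = (z - v) ** 2
    # Cross-entropy between the MCTS distribution pi and the network policy p;
    # moves with pi_i = 0 contribute nothing, so skip them to avoid log(0)
    policy_loss = -sum(pi_i * math.log(p_i) for pi_i, p_i in zip(pi, p) if pi_i > 0)
    l2_penalty = c * sum(w * w for w in theta)
    return value_loss + policy_loss + l2_penalty
```

With the `LogSoftmax` policy head above, the network already returns \(\log p\), so a tensor version would use those outputs directly rather than taking a log again.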

Making it fast

Getting the algorithm right was about half the work. Getting it to run at a useful speed required engineering tricks the paper doesn’t bother with.

The biggest single win was running MCTS simulations in parallel and batching the network evaluations:

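```python
import torch
import torch.multiprocessing as mp

class ParallelMCTS:
    def __init__(self, network: torch.nn.Module, num_processes: int = 8,
                 batch_size: int = 32, num_simulations: int = 800):
        self.network = network
        # Pool of worker processes for running simulations in parallel
        self.pool = mp.Pool(num_processes)
        self.batch_size = batch_size
        self.num_simulations = num_simulations

    def run_batch_evaluation(self, positions: list[torch.Tensor]):
        """Evaluate leaf positions batch_size at a time instead of one by one."""
        self.network.eval()  # BatchNorm must use running statistics during search
        results = []
        with torch.no_grad():
            for i in range(0, len(positions), self.batch_size):
                batch = torch.stack(positions[i:i + self.batch_size])
                log_policies, values = self.network(batch)
                # Policy head emits log-probabilities; value head has shape (B, 1)
                results.extend(zip(log_policies.exp(), values.squeeze(1)))
        return results
```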

The network is the bottleneck, and sending one position at a time wastes most of the GPU. Batching turned a slow training loop into a usable one.
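The batching idea does not depend on PyTorch. The sketch below queues leaf positions and evaluates them one full batch at a time. `evaluate_batch` is a hypothetical stand-in for the network call. It takes a list of positions and returns one result per position.

```python
from collections.abc import Callable, Sequence

class BatchedEvaluator:
    """Queue positions and send them to the network batch_size at a time."""

    def __init__(self, evaluate_batch: Callable[[Sequence], list], batch_size: int = 32):
        self.evaluate_batch = evaluate_batch  # hypothetical network call
        self.batch_size = batch_size
        self.pending: list = []
        self.results: list = []
        self.network_calls = 0

    def submit(self, position) -> None:
        self.pending.append(position)
        if len(self.pending) >= self.batch_size:
            self.flush()

    def flush(self) -> None:
        # One network call for the whole pending batch instead of one per position
        if self.pending:
            self.results.extend(self.evaluate_batch(self.pending))
            self.network_calls += 1
            self.pending = []
```

With `batch_size=32`, submitting 100 positions and then calling `flush()` takes 4 network calls instead of 100. Results come back in submission order, so each one can be matched to the leaf that requested it.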

The other obvious win was caching. Self-play games revisit the same positions a lot, especially in the opening:

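```python
from collections import OrderedDict
import numpy as np

class KVCache:
    def __init__(self, max_size: int = 100000):
        # LRU cache: board hash -> network (policy, value) output
        self.cache: OrderedDict[str, tuple[np.ndarray, float]] = OrderedDict()
        self.max_size = max_size

    def get_cached_value(self, board_hash: str):
        if board_hash in self.cache:
            self.cache.move_to_end(board_hash)  # mark as recently used
        return self.cache.get(board_hash)

    def put(self, board_hash: str, policy: np.ndarray, value: float):
        self.cache[board_hash] = (policy, value)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # evict the least recently used entry
```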
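When the cache only needs to live inside one process, Python's standard `functools.lru_cache` gives the same bounded LRU behavior without a custom class. The sketch below makes two assumptions. Boards are encoded as tuples so they can serve as cache keys, and `network_evaluate` is a hypothetical placeholder that returns a uniform policy over empty cells in place of a real network call.

```python
from functools import lru_cache

def network_evaluate(board: tuple[int, ...]) -> tuple[tuple[float, ...], float]:
    # Hypothetical placeholder: uniform policy over empty cells, neutral value
    legal = [i for i, cell in enumerate(board) if cell == 0]
    policy = tuple(1 / len(legal) if i in legal else 0.0 for i in range(len(board)))
    return policy, 0.0

@lru_cache(maxsize=100_000)
def cached_evaluate(board: tuple[int, ...]) -> tuple[tuple[float, ...], float]:
    # The board tuple itself is the cache key; repeated positions skip the network
    return network_evaluate(board)
```

The cached values are tuples, so no caller can accidentally mutate an entry shared with other lookups. Each worker process gets its own separate `lru_cache`, which matters once self-play runs across a process pool.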

Results

Once parallel MCTS, batched evaluations, and the position cache were in place, the numbers improved across the board:

  • Self-play got 3.2× faster.
  • Training was 2.8× faster.
  • Memory usage dropped by 45%.

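We also tried path consistency optimization, which penalizes value predictions that disagree along a search path:

\[L_{PC} = \|f_v - \bar{f}_v\|^2\]

where \(f_v\) is the network's value prediction for a state on the path and \(\bar{f}_v\) is the average of the value predictions along that path. It made the network learn faster, although I'm still not entirely sure why it helps as much as it does.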
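Under one reading of the formula, the loss for a single path treats \(f_v\) as the vector of value predictions along the path and \(\bar{f}_v\) as their mean. The sketch below follows that reading. It assumes every value is expressed from the same player's perspective. In a two-player game, values from alternating plies would need their signs flipped first.

```python
def path_consistency_loss(path_values: list[float]) -> float:
    """||f_v - mean(f_v)||^2 over the value predictions along one search path.

    Assumes all values are from the same player's perspective.
    """
    mean_value = sum(path_values) / len(path_values)
    return sum((v - mean_value) ** 2 for v in path_values)
```

A path whose predictions all agree costs nothing. For `[0.2, 0.4, 0.6]`, the loss is 0.04 + 0 + 0.04 = 0.08.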

What I’d say to anyone trying this

Implementing research papers is harder than it looks. The papers make everything sound clean because they have to fit in eight pages. The actual work is in the parts that don’t get written down: how you batch, how you cache, how you handle the long tail of edge cases that come up in self-play.

Most of the speed gap between “the paper says” and “what we measured” came from those details. Some of them really are obvious to the authors, and some of them are obvious only after you’ve spent a week debugging your own implementation.

The code is on GitHub if you want to look through it.