
While working at Cogito NTNU, I took on the challenge of implementing AlphaZero from scratch. What I thought would be a straightforward paper implementation turned into a multi-month journey of debugging and optimization. Along the way, I learned firsthand why research papers tend to gloss over the nitty-gritty implementation details.
How AlphaZero Actually Works
AlphaZero combines Monte Carlo Tree Search (MCTS) with a neural network that predicts two things: which moves are good, and who’s winning. The training loop is pretty straightforward (sketched in code right after this list):
- Self-play to generate training data
- Train the neural network on that data
- Test whether the new network beats the old one
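Here’s the shape of that loop in Python. This is a sketch only: self_play, train, and evaluate_against are placeholder names, not our real API, and the 55% promotion threshold is the conventional choice from the AlphaGo Zero paper rather than something we tuned.

```python
# Sketch of the outer loop; self_play, train, and evaluate_against
# are placeholder functions, not our real API.

def alphazero_loop(network, iterations=100):
    best = network
    replay_buffer = []

    for _ in range(iterations):
        # 1. Self-play: the current best network plays itself and
        #    records (state, MCTS policy, final outcome) examples.
        replay_buffer.extend(self_play(best, games=100))

        # 2. Train a candidate network on the accumulated examples.
        candidate = train(best, replay_buffer)

        # 3. Pit the candidate against the current best; promote it
        #    only if it wins a clear majority of evaluation games.
        if evaluate_against(candidate, best, games=40) > 0.55:
            best = candidate

    return best
```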
The neural network outputs a policy vector $\pi$ (move probabilities) and a value estimate $v \in [-1, 1]$ for how good the current position is.
Our network followed the standard two-headed AlphaZero design: a shared convolutional body feeding a policy head and a value head.
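Here’s a minimal sketch of that structure in PyTorch. The layer sizes are for a 3x3 Tic Tac Toe board and purely illustrative; they aren’t our exact architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyValueNet(nn.Module):
    def __init__(self, board_size=3, channels=64):
        super().__init__()
        # Shared body: two input planes encode the board from the
        # current player's perspective.
        self.conv1 = nn.Conv2d(2, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        flat = channels * board_size * board_size
        # Policy head: one logit per board cell.
        self.policy_head = nn.Linear(flat, board_size * board_size)
        # Value head: a single scalar squashed into [-1, 1].
        self.value_head = nn.Linear(flat, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(start_dim=1)
        policy_logits = self.policy_head(x)     # softmax happens in the loss
        value = torch.tanh(self.value_head(x))  # v in [-1, 1]
        return policy_logits, value
```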
Figure 1: Our MCTS implementation visualized: a Tic Tac Toe board (left) and its search tree (right), annotated with visit counts and move numbers at each node, starting from 500 visits at the root.
The Loss Function
The network learns by minimizing this loss:
$$L = (z - v)^2 - \pi^\top \log p + c\|\theta\|^2$$
where $z$ is the actual game outcome, $v$ is the network’s value prediction, $\pi$ is the move distribution MCTS settled on, $p$ is the network’s policy output, and the last term is plain L2 regularization to keep the weights from overfitting.
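In PyTorch the first two terms are only a few lines. A sketch (note that the $c\|\theta\|^2$ term usually isn’t written into the loss at all, since the optimizer’s weight_decay argument handles it):

```python
import torch
import torch.nn.functional as F

def alphazero_loss(policy_logits, value, target_pi, z):
    # (z - v)^2: mean squared error on the game outcome.
    value_loss = F.mse_loss(value.squeeze(-1), z)
    # -pi^T log p: cross-entropy against the soft MCTS visit
    # distribution, averaged over the batch.
    log_p = F.log_softmax(policy_logits, dim=1)
    policy_loss = -(target_pi * log_p).sum(dim=1).mean()
    return value_loss + policy_loss

# The c * ||theta||^2 term comes from weight decay, e.g.:
# optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)
```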
Making It Actually Fast
The biggest lesson was that the algorithm is only half the battle. Getting good performance required a lot of engineering tricks that nobody talks about in papers.
Parallel Everything
Running MCTS self-play in parallel across worker processes was huge for performance.
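Here’s a sketch of the pattern with torch.multiprocessing. The self_play_game function is a placeholder for the actual game loop, and with a CUDA model you’d also want mp.set_start_method("spawn") before launching workers:

```python
import torch.multiprocessing as mp

def self_play_worker(network, games, out_queue):
    # Each worker plays complete games independently and ships its
    # (state, policy, outcome) examples back through a queue.
    # self_play_game is a placeholder, not our real function name.
    examples = []
    for _ in range(games):
        examples.extend(self_play_game(network))
    out_queue.put(examples)

def parallel_self_play(network, workers=8, games_per_worker=25):
    network.share_memory()  # all workers read the same weights
    out_queue = mp.Queue()
    procs = [mp.Process(target=self_play_worker,
                        args=(network, games_per_worker, out_queue))
             for _ in range(workers)]
    for p in procs:
        p.start()
    # Drain the queue before joining, so workers never block on a
    # full pipe while we wait for them to exit.
    examples = [ex for _ in procs for ex in out_queue.get()]
    for p in procs:
        p.join()
    return examples
```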
Smart Caching
We also cached network evaluations to avoid scoring the same position over and over; between transpositions and repeated sub-games, MCTS hits duplicate positions constantly.
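A minimal sketch of the idea, keyed on the raw board bytes (the class and its eviction policy are illustrative, not our exact implementation):

```python
import numpy as np

class EvaluationCache:
    # Memoizes (policy, value) network outputs per position.
    def __init__(self, max_size=500_000):
        self.cache = {}
        self.max_size = max_size
        self.hits = 0
        self.misses = 0

    def lookup(self, board: np.ndarray):
        # tobytes() gives a cheap, exact dict key for small boards;
        # Zobrist hashing is the usual upgrade for bigger games.
        result = self.cache.get(board.tobytes())
        if result is not None:
            self.hits += 1
        else:
            self.misses += 1
        return result

    def store(self, board: np.ndarray, policy: np.ndarray, value: float):
        if len(self.cache) >= self.max_size:
            self.cache.clear()  # crude eviction, but bounds memory
        self.cache[board.tobytes()] = (policy, value)
```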
What We Actually Achieved
After all the optimizations, we got some pretty solid improvements:
- Self-play got 3.2x faster
- Network training sped up by 2.8x
- Memory usage dropped by 45%
We also tried path consistency optimization, which pushes the value predictions along a single search path to agree with each other:
$$L_{PC} = \|f_v - \bar{f}_v\|^2$$
where $f_v$ is the network’s value prediction at a node on the search path and $\bar{f}_v$ is the mean of those predictions along the path.
It helped the network learn faster, though honestly I’m still not 100% sure why it works so well. The rough intuition is that every position along an optimal path should have the same value (the eventual outcome), so penalizing variance along the path gives the value head an extra, cheap training signal.
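As a sketch, the penalty is tiny to compute once you’ve collected the value outputs along one search path (function and variable names here are illustrative):

```python
import torch

def path_consistency_loss(path_values: torch.Tensor) -> torch.Tensor:
    # path_values holds the network's value outputs for each node on
    # one search path, all from the same player's perspective.
    # Computes mean((f_v - mean(f_v))^2), i.e. variance along the path.
    f_bar = path_values.mean()
    return ((path_values - f_bar) ** 2).mean()
```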
What I Actually Learned
Building this taught me that implementing research papers is way harder than it looks. The papers make everything sound clean and simple, but getting something that actually runs fast requires tons of boring engineering work that never gets mentioned.
The biggest time sink was debugging why our implementation was so much slower than reported results. Turns out a lot of the speedup comes from implementation details that are just assumed knowledge.
We put the code up on GitHub if you want to check it out. Fair warning though: getting it to run well takes some patience.