
During my time at Cogito NTNU, I dove deep into implementing AlphaZero from scratch. What started as a fun project turned into a fascinating exploration of deep reinforcement learning, parallel computing, and performance optimization. Let me share what I learned along the way.
The Core Algorithm
At its heart, AlphaZero combines Monte Carlo Tree Search (MCTS) with a neural network that predicts both move probabilities and position values. The training process involves an elegant loop, sketched in code below:
- Self-play to generate training data
- Neural network training
- Model evaluation
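As a rough sketch of that loop, with the three steps passed in as callables (the signatures, iteration count, and promotion threshold below are illustrative rather than the exact ones from our codebase):

```python
from typing import Callable

# Illustrative outline of the outer training loop. The three callables stand in
# for the project's actual self-play, training, and evaluation code.
def training_loop(
    network,
    self_play: Callable,   # network -> list of (state, mcts_policy, outcome) examples
    train: Callable,       # (network, examples) -> candidate network
    evaluate: Callable,    # (candidate, best) -> candidate win rate in [0, 1]
    iterations: int = 100,
    win_threshold: float = 0.55,   # assumed promotion threshold
):
    best = network
    for _ in range(iterations):
        examples = self_play(best)           # 1. generate data via self-play
        candidate = train(best, examples)    # 2. fit the policy and value targets
        if evaluate(candidate, best) >= win_threshold:
            best = candidate                 # 3. promote only if clearly stronger
    return best
```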
The policy-value network outputs two key components:
- A policy vector $\pi$ representing move probabilities
- A value estimate $v\in[-1,1]$ for the current position
Here's a simplified sketch of the network architecture. The input planes, channel counts, and layer sizes below are illustrative rather than our exact configuration, but the overall shape matches the dual-headed design described above: a shared convolutional body feeding a policy head and a value head.
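```python
import torch
import torch.nn as nn

# Illustrative policy-value network for a small board game such as Tic Tac Toe.
# The input planes, channel counts, and layer sizes are placeholders, not our exact values.
class PolicyValueNet(nn.Module):
    def __init__(self, board_size: int = 3, channels: int = 64):
        super().__init__()
        # Shared convolutional body (assumes 3 input planes: own stones, opponent stones, turn).
        self.body = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        flat = channels * board_size * board_size
        # Policy head: one logit per board square.
        self.policy_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat, board_size * board_size),
        )
        # Value head: a scalar in [-1, 1] for the side to move.
        self.value_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        h = self.body(x)
        return self.policy_head(h), self.value_head(h)
```

A softmax over the policy logits (masked to legal moves) gives the move probabilities $p$, and the tanh output is the value estimate $v \in [-1, 1]$.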
Figure 1: A visualization of our MCTS implementation: a Tic Tac Toe board (red and white circles, plus one empty black position) on the left, and the corresponding search tree on the right, with visit counts and move numbers at each node, starting from 500 visits at the root.
Training Optimization
The loss function combines policy and value objectives (a code sketch follows the symbol definitions below):
$$L = (z - v)^2 - \pi^T \log p + c|\theta|^2$$
where:
- $z$ is the game outcome
- $v$ is the predicted value
- $\pi$ is the MCTS policy
- $p$ is the predicted policy
- $c|\theta|^2$ is L2 regularization
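In code, the first two terms are a mean-squared error on the value and a cross-entropy against the MCTS policy. A minimal sketch, assuming batched tensors and with the $c|\theta|^2$ term handled by the optimizer's weight decay (the tensor shapes and names are assumptions):

```python
import torch
import torch.nn.functional as F

# z:        game outcomes in [-1, 1], shape (batch,)
# v:        predicted values, shape (batch,)
# pi:       MCTS visit-count policies, shape (batch, num_moves), rows sum to 1
# p_logits: raw policy logits from the network, shape (batch, num_moves)
def alphazero_loss(z, v, pi, p_logits):
    value_loss = F.mse_loss(v, z)                     # (z - v)^2, averaged over the batch
    log_p = F.log_softmax(p_logits, dim=1)            # log p
    policy_loss = -(pi * log_p).sum(dim=1).mean()     # -pi^T log p, averaged over the batch
    return value_loss + policy_loss

# The c * |theta|^2 regularization term is typically applied through the optimizer,
# e.g. torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4).
```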
Parallelization Tricks
One thing that wasn't immediately obvious was how much performance we could gain through proper parallelization. What worked well for us was running self-play in parallel worker processes with torch.multiprocessing; the sketch below shows the general pattern (the worker count, queue protocol, and helper names are illustrative):
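```python
import torch
import torch.multiprocessing as mp

# Illustrative self-play worker pool. The queue protocol, worker count, and the
# play_one_game helper are placeholders, not the exact code from our repo.
def self_play_worker(worker_id, model_state, games_per_worker, out_queue):
    # Each worker rebuilds the network from a shared state_dict and plays its own games.
    model = PolicyValueNet()              # network class from the earlier sketch
    model.load_state_dict(model_state)
    model.eval()
    for _ in range(games_per_worker):
        record = play_one_game(model)     # placeholder: one MCTS-driven self-play game
        out_queue.put(record)

def generate_self_play_data(model, num_workers: int = 8, games_per_worker: int = 64):
    mp.set_start_method("spawn", force=True)   # spawn is the safest choice when CUDA is involved
    out_queue = mp.Queue()
    state = {k: v.cpu() for k, v in model.state_dict().items()}
    workers = [
        mp.Process(target=self_play_worker, args=(i, state, games_per_worker, out_queue))
        for i in range(num_workers)
    ]
    for w in workers:
        w.start()
    # Drain the queue before joining so workers never block on a full queue.
    records = [out_queue.get() for _ in range(num_workers * games_per_worker)]
    for w in workers:
        w.join()
    return records
```

Passing a CPU state_dict instead of the live module keeps the workers independent and avoids sharing CUDA state across processes.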
Memory Management
A critical optimization was implementing smart key-value caching for MCTS. This reduced memory usage and improved inference speed. The sketch below shows one way to structure such a cache, mapping board positions to the network's policy and value outputs (the key scheme and eviction policy are illustrative):
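```python
from collections import OrderedDict
import numpy as np

# Illustrative evaluation cache for MCTS: maps a board position to the network's
# (policy, value) output so repeated positions are evaluated only once.
# The byte-string key and LRU eviction are assumptions, not our exact scheme.
class EvaluationCache:
    def __init__(self, max_entries: int = 100_000):
        self.max_entries = max_entries
        self._store = OrderedDict()   # key bytes -> (policy, value)

    def _key(self, board: np.ndarray) -> bytes:
        # A contiguous byte view of the board is a cheap, hashable key.
        return np.ascontiguousarray(board).tobytes()

    def get(self, board: np.ndarray):
        k = self._key(board)
        if k in self._store:
            self._store.move_to_end(k)        # mark as recently used
            return self._store[k]
        return None

    def put(self, board: np.ndarray, policy: np.ndarray, value: float):
        k = self._key(board)
        self._store[k] = (policy, value)
        self._store.move_to_end(k)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)   # evict the least recently used entry
```

The natural place to consult such a cache is right before the network call during node expansion, so transpositions reached through different move orders are only evaluated once.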
Performance Results
After implementing these optimizations, we saw significant improvements:
- 3.2x speedup in self-play data generation
- 2.8x faster neural network training
- 45% reduction in memory usage
The path consistency optimization was particularly effective, speeding up learning by enforcing value consistency along search paths:
$$L_{PC} = |f_v - \bar{f}_v|^2$$
where $f_v$ represents feature maps and $\bar{f}_v$ is their element-wise average.
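A minimal sketch of this term, assuming the value-head feature maps for the states along one search path are stacked into a single tensor (the shape and the name path_features are assumptions):

```python
import torch

# Path consistency loss: penalize deviation of each state's value feature map
# from the mean feature map along the search path.
# path_features: (path_length, feature_dim) tensor of value-head features f_v;
# the exact shape is an assumption for this sketch.
def path_consistency_loss(path_features: torch.Tensor) -> torch.Tensor:
    mean_feature = path_features.mean(dim=0, keepdim=True)   # element-wise average of f_v
    return ((path_features - mean_feature) ** 2).mean()      # squared deviation, averaged
```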
Lessons Learned
Looking back, the project taught me a lot about the practical challenges of implementing research papers. The theory might look clean, but getting good performance requires careful attention to systems-level optimizations and smart engineering choices.
We’ve open-sourced our implementation - feel free to check it out here and let us know if you find ways to make it even faster. There’s always room for improvement!