Implementing Fast AlphaZero

During my time at Cogito NTNU, I dove deep into implementing AlphaZero from scratch. What started as a fun project turned into a fascinating exploration of deep reinforcement learning, parallel computing, and performance optimization. Let me share what I learned along the way.

The Core Algorithm

At its heart, AlphaZero combines Monte Carlo Tree Search (MCTS) with a neural network that predicts both move probabilities and position values. The training process involves an elegant loop:

  1. Self-play to generate training data
  2. Neural network training
  3. Model evaluation
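
Stitched together, the outer loop looks roughly like the sketch below. The helper functions (`self_play`, `train_network`, `evaluate_against_best`) and the 55% acceptance threshold are placeholders for illustration, not our exact API:

def training_loop(network, num_iterations: int = 100):
    """Sketch of the self-play / train / evaluate cycle (helpers are hypothetical)."""
    best_network = network
    replay_buffer = []

    for _ in range(num_iterations):
        # 1. Self-play: the current best network plays against itself via MCTS
        replay_buffer.extend(self_play(best_network, num_games=100))

        # 2. Training: fit the network to (state, MCTS policy, outcome) examples
        train_network(network, replay_buffer)

        # 3. Evaluation: promote the new weights only if they beat the old ones
        if evaluate_against_best(network, best_network, num_games=40) > 0.55:
            best_network = network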

The policy-value network outputs two key components:

  • A policy vector $\pi$ representing move probabilities
  • A value estimate $v\in[-1,1]$ for the current position

Here’s a simplified version of our network architecture:

import torch
import torch.nn as nn

class PolicyValueNet(nn.Module):
    """
    Neural network that outputs both policy and value predictions.

    Args:
        board_size: Size of the game board (height, width)
        action_size: Number of possible actions
        num_channels: Number of channels in convolutional layers
    """
    def __init__(self, board_size: tuple[int, int], action_size: int, num_channels: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, num_channels, 3, padding=1),  # 3 input feature planes in the board encoding
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, 3, padding=1),
            nn.BatchNorm2d(num_channels),
            nn.ReLU()
        )

        board_size_flat = board_size[0] * board_size[1]
        self.policy_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, action_size),
            nn.LogSoftmax(dim=1)
        )

        self.value_head = nn.Sequential(
            nn.Linear(num_channels * board_size_flat, 1),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (log move probabilities, value) for a batch of encoded boards."""
        features = self.conv(x).flatten(start_dim=1)
        return self.policy_head(features), self.value_head(features)
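
As a quick shape check, the network can be exercised on a dummy Tic Tac Toe position. The 3-plane input (matching the `Conv2d(3, ...)` above, e.g. own pieces, opponent pieces, side to move) is an assumption about how boards are encoded:

net = PolicyValueNet(board_size=(3, 3), action_size=9, num_channels=64)
net.eval()  # use running BatchNorm statistics for single-position inference
with torch.no_grad():
    dummy_board = torch.zeros(1, 3, 3, 3)  # (batch, planes, height, width)
    log_policy, value = net(dummy_board)
print(log_policy.shape, value.shape)  # torch.Size([1, 9]) torch.Size([1, 1])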


Figure 1: A visualization of our MCTS implementation showing a Tic Tac Toe board with red and white circles (and one empty black position) on the left, and its corresponding search tree on the right displaying visit counts and move numbers at each node, starting from 500 visits at the root.

Training Optimization

The loss function combines policy and value objectives:
$$L = (z - v)^2 - \pi^T \log p + c|\theta|^2$$
where:

  • $z$ is the game outcome
  • $v$ is the predicted value
  • $\pi$ is the MCTS policy
  • $p$ is the predicted policy
  • $c|\theta|^2$ is L2 regularization
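
In PyTorch, the first two terms might be computed along these lines (a sketch, assuming the network emits log-probabilities as with the `LogSoftmax` head above); the $c|\theta|^2$ term is usually supplied by the optimizer's weight decay rather than written out by hand:

import torch
import torch.nn.functional as F

def alphazero_loss(log_policy: torch.Tensor, value: torch.Tensor,
                   mcts_policy: torch.Tensor, outcome: torch.Tensor) -> torch.Tensor:
    """Combined loss: (z - v)^2 - pi^T log p, averaged over the batch."""
    value_loss = F.mse_loss(value.squeeze(-1), outcome)          # (z - v)^2
    policy_loss = -(mcts_policy * log_policy).sum(dim=1).mean()  # -pi^T log p
    return value_loss + policy_loss

# The L2 term c|theta|^2 corresponds to weight decay in the optimizer, e.g.
# torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)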

Parallelization Tricks

One thing that wasn’t immediately obvious was how much performance we could gain through proper parallelization. Here’s what worked well for us:

import torch
import torch.multiprocessing as mp

class ParallelMCTS:
    """
    Parallel MCTS implementation using multiple processes.

    Args:
        network: Policy-value network used to evaluate leaf positions
        num_processes: Number of parallel processes to use
        batch_size: Size of evaluation batches
        num_simulations: Number of MCTS simulations per move
    """
    def __init__(self, network: torch.nn.Module, num_processes: int = 8,
                 batch_size: int = 32, num_simulations: int = 800):
        self.network = network
        self.pool = mp.Pool(num_processes)  # worker pool for parallel search (use not shown in this snippet)
        self.batch_size = batch_size
        self.num_simulations = num_simulations

    def run_batch_evaluation(self, positions: list[torch.Tensor]) -> list[tuple[torch.Tensor, float]]:
        """Evaluates a batch of positions using the neural network."""
        with torch.no_grad():
            policies, values = [], []
            for i in range(0, len(positions), self.batch_size):
                batch = positions[i:i + self.batch_size]
                batch_policies, batch_values = self.network(torch.stack(batch))
                policies.extend(batch_policies)
                values.extend(batch_values.squeeze(-1).tolist())  # tensors -> plain floats
            return list(zip(policies, values))
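
A rough usage sketch, reusing the policy-value network defined earlier and dummy 3-plane board encodings:

if __name__ == "__main__":  # guard required for multiprocessing on spawn-based platforms
    net = PolicyValueNet(board_size=(3, 3), action_size=9, num_channels=64)
    net.eval()
    mcts = ParallelMCTS(network=net, num_processes=4, batch_size=32)
    positions = [torch.zeros(3, 3, 3) for _ in range(100)]  # dummy board encodings
    evaluations = mcts.run_batch_evaluation(positions)
    print(len(evaluations))  # 100 (log policy, value) pairs

Grouping positions into batches keeps the accelerator busy instead of paying per-position inference overhead, which is where batched evaluation gains its speed.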

Memory Management

A critical optimization was a bounded key-value cache for MCTS node evaluations: capping its size kept memory usage under control, while reusing cached results improved inference speed:

from collections import OrderedDict

import numpy as np

class KVCache:
    """
    Key-value cache for MCTS node evaluations.

    Args:
        max_size: Maximum number of positions to cache
    """
    def __init__(self, max_size: int = 100000):
        self.cache: OrderedDict[str, tuple[np.ndarray, float]] = OrderedDict()
        self.max_size = max_size

    def get_cached_value(self, board_hash: str) -> tuple[np.ndarray, float] | None:
        if board_hash in self.cache:
            self.cache.move_to_end(board_hash)  # mark as recently used
        return self.cache.get(board_hash)

    def store(self, board_hash: str, policy: np.ndarray, value: float) -> None:
        """Caches a (policy, value) evaluation, evicting the least recently used entry when full."""
        self.cache[board_hash] = (policy, value)
        self.cache.move_to_end(board_hash)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # drop the least recently used entry
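
In the search loop it might be used along these lines; hashing the raw board bytes is just one possible keying scheme, and `evaluate_position` stands in for a real network call:

cache = KVCache(max_size=50000)

def cached_evaluate(board: np.ndarray) -> tuple[np.ndarray, float]:
    board_hash = board.tobytes().hex()  # simple, assumed keying scheme
    cached = cache.get_cached_value(board_hash)
    if cached is not None:
        return cached
    policy, value = evaluate_position(board)  # hypothetical network call
    cache.store(board_hash, policy, value)
    return policy, value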

Performance Results

After implementing these optimizations, we saw significant improvements:

  • 3.2x speedup in self-play data generation
  • 2.8x faster neural network training
  • 45% reduction in memory usage

The path consistency optimization was particularly effective, speeding up learning by enforcing value consistency along search paths:

$$L_{PC} = |f_v - \bar{f}_v|^2$$

where $f_v$ represents feature maps and $\bar{f}_v$ is their element-wise average.
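
As a rough illustration (not necessarily how we wired it in), the extra term can be computed from the feature maps of the states visited along one search path:

import torch

def path_consistency_loss(path_features: torch.Tensor) -> torch.Tensor:
    """
    path_features: (path_length, feature_dim) features of states along one search path.
    Penalizes each state's deviation from the element-wise mean over the path.
    """
    mean_features = path_features.mean(dim=0, keepdim=True)  # the element-wise average
    return ((path_features - mean_features) ** 2).sum(dim=1).mean()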

Lessons Learned

Looking back, the project taught me a lot about the practical challenges of implementing research papers. The theory might look clean, but getting good performance requires careful attention to systems-level optimizations and smart engineering choices.

We’ve open-sourced our implementation - feel free to check it out here and let us know if you find ways to make it even faster. There’s always room for improvement!