io-chess
UCI chess engine
Training Infrastructure

Dataset Loading

The dataset.py module provides two specialised sampler classes designed for large-scale training on binary datasets that do not fit in RAM:

  • ChunkedRandomSampler: Loads data in large contiguous chunks from memory-mapped files, shuffles positions within each chunk, and yields mini-batches. RAM usage is proportional to the chunk size (typically 1-4 GB), not the total dataset size.
  • ExpertChunkedSampler: Extends the chunked sampler to filter positions by algorithmic expert label, used during Phase 2 to train each expert on its assigned subset.

Both samplers support multi-worker data loading via PyTorch's DataLoader with num_workers > 0.
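
The chunked sampling idea above can be sketched over plain position indices. This is a minimal illustration, not the code in dataset.py: the function name `chunked_batches`, its parameters, and the choice to visit chunks in random order are all assumptions. The optional `labels`/`expert` filter mimics what ExpertChunkedSampler is described as doing.

```python
import random
from typing import Iterator, List, Optional, Sequence


def chunked_batches(
    num_positions: int,
    chunk_size: int,
    batch_size: int,
    seed: int = 0,
    labels: Optional[Sequence[int]] = None,  # hypothetical per-position expert labels
    expert: Optional[int] = None,            # keep only positions with this label
) -> Iterator[List[int]]:
    """Yield mini-batches of position indices, one contiguous chunk at a time."""
    rng = random.Random(seed)
    chunk_starts = list(range(0, num_positions, chunk_size))
    rng.shuffle(chunk_starts)  # assumption: chunks are visited in random order
    for start in chunk_starts:
        stop = min(start + chunk_size, num_positions)
        # Only this chunk's indices are materialised, so memory use scales
        # with chunk_size rather than with num_positions.
        indices = list(range(start, stop))
        if labels is not None and expert is not None:
            indices = [i for i in indices if labels[i] == expert]
        rng.shuffle(indices)  # shuffle positions within the chunk
        for offset in range(0, len(indices), batch_size):
            yield indices[offset:offset + batch_size]
```

In a real sampler each index would address a record in a memory-mapped file; the trade-off is that shuffling is only local to a chunk, so larger chunks give better mixing at the cost of more RAM.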

Loss Functions

The loss.py module defines the training objective:

| Component | Description |
|---|---|
| WDL Brier-style MSE | Mean squared error between predicted and target Win/Draw/Loss probabilities |
Note: The network is trained to predict Win/Draw/Loss probabilities directly; the engine converts them to centipawns internally during search. Because routing is handled algorithmically during preprocessing, training needs no routing or load-balancing loss term.
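
As a rough illustration of the Brier-style MSE objective described above, the following pure-Python sketch averages the squared error over every (position, outcome) entry of a batch. The function name `wdl_brier_mse` and the choice to average over all batch × 3 entries are assumptions; loss.py may reduce differently (for example, summing over the three outcomes before averaging over the batch).

```python
from typing import Sequence


def wdl_brier_mse(
    pred: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
) -> float:
    """Mean squared error between predicted and target (win, draw, loss) rows."""
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and the same length")
    total = 0.0
    count = 0
    for p_row, t_row in zip(pred, target):
        if len(p_row) != 3 or len(t_row) != 3:
            raise ValueError("each row must hold exactly three probabilities")
        for p, t in zip(p_row, t_row):
            total += (p - t) ** 2
            count += 1
    # Assumption: average over all batch x 3 entries.
    return total / count
```

A perfect prediction scores 0, and the loss grows quadratically as probability mass moves to the wrong outcome.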

Weight Export

After training, export.py converts the PyTorch state dict into the engine's flat binary weight format. The export script:

  1. Loads the checkpoint and extracts all parameter tensors.
  2. Quantises weights where appropriate.
  3. Writes a flat binary file consisting of a simple header followed by the raw weight data.

The engine loads this file at startup and constructs the native C++ inference graph from the serialised weights.

python export.py --checkpoint checkpoints/model_finetune_phase4_final.pt --output weights.bin
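
To make the "header followed by raw data" layout concrete, here is a minimal round-trip sketch using only the standard library. The layout is entirely hypothetical: the `IOCW` magic bytes, the tensor-count and length fields, and the little-endian float32 encoding are assumptions for illustration, not the engine's actual format. Quantisation is left out.

```python
import io
import struct
from typing import BinaryIO, Dict, List

MAGIC = b"IOCW"  # hypothetical magic bytes, not the engine's real header


def write_flat_weights(out: BinaryIO, tensors: Dict[str, List[float]]) -> None:
    """Write a small header, then every tensor as raw little-endian float32."""
    out.write(MAGIC)
    out.write(struct.pack("<I", len(tensors)))       # number of tensors
    for values in tensors.values():
        out.write(struct.pack("<I", len(values)))    # element count per tensor
    for values in tensors.values():
        out.write(struct.pack(f"<{len(values)}f", *values))


def read_flat_weights(src: BinaryIO) -> List[List[float]]:
    """Read back the tensors in the order they were written."""
    if src.read(4) != MAGIC:
        raise ValueError("not a weight file in the sketched format")
    (count,) = struct.unpack("<I", src.read(4))
    lengths = [struct.unpack("<I", src.read(4))[0] for _ in range(count)]
    return [list(struct.unpack(f"<{n}f", src.read(4 * n))) for n in lengths]


# Round trip through an in-memory buffer instead of a real weights.bin.
buffer = io.BytesIO()
write_flat_weights(buffer, {"layer0.weight": [1.0, -0.5], "layer0.bias": [0.25]})
buffer.seek(0)
restored = read_flat_weights(buffer)
```

Parameter names are not stored in this sketch, so the reader relies on a fixed tensor order, which is one common way to keep a flat format simple.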

Experiment Tracking

Training supports optional Weights & Biases integration for logging loss curves, learning rate schedules, expert utilisation histograms, and gradient norms. Enable it with the --wandb flag:

python train.py --phase 1 --wandb --project io-chess --data_dirs /data/chess

Tracked metrics include:

| Metric | Description |
|---|---|
| train/loss | Total training loss |
| train/wdl_loss | WDL loss component (the Brier-style MSE defined in loss.py) |
| val/loss | Validation loss (computed every N steps) |
| lr | Current learning rate |
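
One way to keep the Weights & Biases dependency optional is to import it only when logging is enabled. The sketch below is an assumption about how such a switch might be wired, not the code in train.py; `make_metric_logger` is a hypothetical helper, while `wandb.init(project=...)` and `wandb.log(..., step=...)` are the library's standard calls.

```python
from typing import Callable, Dict, Optional

MetricLogger = Callable[[Dict[str, float], int], None]


def make_metric_logger(use_wandb: bool, project: Optional[str] = None) -> MetricLogger:
    """Return a function that logs a dict of metrics for a given step."""
    if use_wandb:
        import wandb  # imported lazily so the dependency stays optional

        wandb.init(project=project)

        def log(metrics: Dict[str, float], step: int) -> None:
            wandb.log(metrics, step=step)
    else:
        def log(metrics: Dict[str, float], step: int) -> None:
            # Fallback when tracking is disabled: print to stdout.
            parts = ", ".join(f"{k}={v:.4f}" for k, v in sorted(metrics.items()))
            print(f"step {step}: {parts}")

    return log
```

A training loop would then call, for example, `log({"train/loss": loss, "lr": lr}, step)` without caring whether tracking is switched on.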