io-chess
UCI chess engine
Loading...
Searching...
No Matches
MoECacheModel.hpp File Reference

Factorized Mixture of Experts (MoE) neural network architecture and inference components. More...

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
#include <chess.hpp>
#include "FactorizedFeatureExtractor.hpp"
#include "PersistentThreadPool.hpp"
Include dependency graph for MoECacheModel.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Classes

struct  BenchConfig
struct  BenchResult
struct  BranchLayer
struct  Branch
struct  Expert
struct  SharedMoEWeights
 Contains the globally shared, read-only weights for the Factorized MoE network. More...
struct  MoEDoubleAccumulator
 Thread-local state for incremental, lightning-fast neural network inference. More...
struct  MoEDoubleAccumulator::PhaseProfile

Macros

#define FORCE_VECTORIZE
#define HOT_RESTRICT   __restrict__

Typedefs

using Clock = std::chrono::high_resolution_clock

Enumerations

enum class  ExpertPoolMode { Flat , Gap , Pool2x2Avg , Pool2x2Max }
 Defines how spatial features are pooled before entering the fully-connected expert networks. More...

Functions

static double us (Clock::time_point a, Clock::time_point b)
static bool is_slider_branch_rich (int branch_idx)
 Determines if a specific branch corresponds to a slider piece type (Bishop, Rook, Queen).
static int active_channels_for_branch (int branch_idx)
 Returns the number of active input feature channels for a given branch.
static bool is_slider_dirty (const Board &board, const Board &old_board, Color side, PieceType pt)
 Checks if a slider piece type (Bishop, Rook, Queen) has become "dirty".
static void mark_head (std::array< uint8_t, 12 > &m, Color s, PieceType pt)
 Marks a specific piece type for a specific color as "dirty" in the dirty mask.
static std::array< uint8_t, 12 > build_dirty_mask (const Board &old_board, const Board &new_board, const Move &mv)
 Analyzes a chess move to determine exactly which of the 12 spatial branches need recomputing.
static bool branch_planes_changed (const FactorizedInput &a, const FactorizedInput &b, int branch_idx)
 Performs a fast byte-level comparison to see if two spatial branch inputs are identical.
static void simd_add_scaled (float *HOT_RESTRICT dst, const float *HOT_RESTRICT src, float scale, int n)
 SIMD-optimized fused multiply-add (FMA) that adds a scaled vector to a destination vector: dst[i] += src[i] * scale for each of the n elements.
template<typename T>
static T * assume_aligned_32 (T *ptr)
 Hints the compiler that a pointer is aligned to a 32-byte boundary (for AVX operations).
template<typename T>
static const T * assume_aligned_32 (const T *ptr)
static void conv3x3_accumulate_plane (float *HOT_RESTRICT out_plane, const float *HOT_RESTRICT in_plane, const float *HOT_RESTRICT wk)
 Computes a full 3x3 2D convolution for a single input channel and accumulates it into out_plane.
static void conv3x3_single_out_relu (const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, float b, float *HOT_RESTRICT out, int ic)
 Computes a 3x3 convolution across multiple input channels (ic) to produce a single output channel, followed by ReLU.
static void conv3x3_relu (const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
 Computes a full 3x3 convolution from ic input channels to oc output channels, followed by ReLU.
static void conv3x3_relu_bd16_dispatch (const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic)
static void depthwise_conv3x3_relu (const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int channels)
 Computes a depthwise 3x3 convolution followed by a ReLU activation.
static void conv1x1_relu (const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
 Computes a 1x1 convolution (pointwise convolution) across the spatial grid, followed by ReLU.
static const char * expert_pool_mode_name (ExpertPoolMode mode)
static int pool2x2_region_from_sq (int sq)
static void pool2x2_region_base (int region, int &r0, int &c0)

Variables

static constexpr int kDefaultBranchDim = 16
static constexpr int kDefaultMixerOut = 64
static constexpr int kDefaultBypass = 12
static constexpr int kDefaultGlobals = 21
static constexpr int kDefaultExperts = 4
static constexpr int kDefaultExpertBottleneck = 32
static constexpr int kDefaultExpertHidden = 128
static constexpr int NET_BRANCH_DIM = 16
static constexpr int NET_MIXER_OUT = 64
static constexpr int NET_BYPASS = 12
static constexpr int NET_GLOBALS = 21
static constexpr int NET_EXPERTS = 4
static constexpr int NET_EXPERT_BOTTLENECK = 32
static constexpr int NET_EXPERT_HIDDEN = 128
static constexpr int kInputBypassPlanes = 12
static constexpr int kInputGlobals = 32
static constexpr int kMaxBranchDim = NET_BRANCH_DIM
static constexpr int kMaxMixerOut = NET_MIXER_OUT
static constexpr int kMaxBypass = NET_BYPASS
static constexpr int kMaxGlobals = NET_GLOBALS
static constexpr int kMaxExperts = NET_EXPERTS
static constexpr int kMaxExpertBottleneck = NET_EXPERT_BOTTLENECK
static constexpr int kMaxExpertHidden = NET_EXPERT_HIDDEN
static constexpr int kMixerOcTile = 8
static constexpr int kPool2x2Regions = 16

Detailed Description

Factorized Mixture of Experts (MoE) neural network architecture and inference components.

This file implements the core inference structures for the natively compiled Factorized MoE network used for position evaluation. The architecture processes a custom 12-branch spatial feature set extracted via FactorizedFeatureExtractor, applies a mixer layer, routes the evaluation to specialized expert networks, and computes the final Win/Draw/Loss probabilities.

Architecture Overview

  • Branches: 12 independent 3x3 convolutional feature branches (e.g., White Pawns, Black Knights, etc.).
  • Mixer: A fully-connected layer that combines the outputs of all 12 branches and global features.
  • Router: A shallow network that selects the most appropriate "Expert" sub-networks for the given position (e.g., Endgame vs Opening).
  • Experts: Specialized fully-connected networks (Bottleneck -> Hidden -> WDL Output).

Double Accumulator (Incremental Inference)

For maximum performance, this file uses a double accumulator pattern (MoEDoubleAccumulator):

  1. Base Accumulator: Stores the state of the network for the current position.
  2. Dirty Accumulator: When a move is made, instead of running the entire network, we only compute the delta (difference) for the specific pieces that moved or were captured (dirty_branches). This reduces the evaluation time from ~20 microseconds (full forward pass) to <1 microsecond (incremental update).
Note
This file contains both the engine evaluation structures (MoEDoubleAccumulator) and the full-batch, multi-threaded evaluation structures (SharedMoEWeights::forward); the latter are retained strictly for the standalone benchmark tools in src/tools/.

Macro Definition Documentation

◆ FORCE_VECTORIZE

#define FORCE_VECTORIZE

◆ HOT_RESTRICT

#define HOT_RESTRICT   __restrict__

Typedef Documentation

◆ Clock

using Clock = std::chrono::high_resolution_clock

Enumeration Type Documentation

◆ ExpertPoolMode

enum class ExpertPoolMode
strong

Defines how spatial features are pooled before entering the fully-connected expert networks.

Enumerator
Flat 
Gap 
Pool2x2Avg 
Pool2x2Max 

Function Documentation

◆ active_channels_for_branch()

int active_channels_for_branch ( int branch_idx)
inlinestatic

Returns the number of active input feature channels for a given branch.

Parameters
branch_idx: The index of the branch.
Returns
5 channels for sliders (includes X-ray features), 4 channels for non-sliders.
Here is the call graph for this function:
Here is the caller graph for this function:

◆ assume_aligned_32() [1/2]

template<typename T>
const T * assume_aligned_32 ( const T * ptr)
inlinestatic

◆ assume_aligned_32() [2/2]

template<typename T>
T * assume_aligned_32 ( T * ptr)
inlinestatic

Hints the compiler that a pointer is aligned to a 32-byte boundary (for AVX operations).

Here is the caller graph for this function:

◆ branch_planes_changed()

bool branch_planes_changed ( const FactorizedInput & a,
const FactorizedInput & b,
int branch_idx )
inlinestatic

Performs a fast byte-level comparison to see if two spatial branch inputs are identical.

Parameters
a: The old factorized input.
b: The new factorized input.
branch_idx: The index of the branch to check.
Returns
True if the inputs differ, false if they are identical.
Here is the call graph for this function:
Here is the caller graph for this function:

◆ build_dirty_mask()

std::array< uint8_t, 12 > build_dirty_mask ( const Board & old_board,
const Board & new_board,
const Move & mv )
static

Analyzes a chess move to determine exactly which of the 12 spatial branches need recomputing.

This is the core of the incremental update logic. It prevents full network re-evaluations by figuring out which specific piece types (e.g., White Pawns, Black Knights) had their board representation changed by the move, including complex cases like discoveries, castling, and en passant.

Parameters
old_board: The board state before the move.
new_board: The board state after the move.
mv: The move that was just played.
Returns
An array of 12 bytes acting as boolean flags (1 = dirty, 0 = clean) for each branch.
Here is the call graph for this function:

◆ conv1x1_relu()

void conv1x1_relu ( const float *HOT_RESTRICT in,
const float *HOT_RESTRICT w,
const float *HOT_RESTRICT b,
float *HOT_RESTRICT out,
int ic,
int oc )
static

Computes a 1x1 convolution (pointwise convolution) across the spatial grid, followed by ReLU.

This operation mixes features across channels for every spatial square independently.

Parameters
in: Flattened input tensor [ic][64].
w: Flattened weights [oc][ic].
b: Flattened biases [oc].
out: Flattened output tensor [oc][64].
ic: Number of input channels.
oc: Number of output channels.
Here is the call graph for this function:
Here is the caller graph for this function:

◆ conv3x3_accumulate_plane()

void conv3x3_accumulate_plane ( float *HOT_RESTRICT out_plane,
const float *HOT_RESTRICT in_plane,
const float *HOT_RESTRICT wk )
inlinestatic

Computes a full 3x3 2D convolution for a single input channel and accumulates it into out_plane.

This avoids boundary checks by breaking the 8x8 chessboard into regions (Center, North/South edges, Corners).

Parameters
out_plane: The destination 64-element array.
in_plane: The source 64-element array representing the spatial features.
wk: A 9-element array representing the 3x3 convolution weights.
Here is the caller graph for this function:

◆ conv3x3_relu()

void conv3x3_relu ( const float *HOT_RESTRICT in,
const float *HOT_RESTRICT w,
const float *HOT_RESTRICT b,
float *HOT_RESTRICT out,
int ic,
int oc )
static

Computes a full 3x3 convolution from ic input channels to oc output channels, followed by ReLU.

Parameters
in: Flattened input tensor [ic][64].
w: Flattened weights [oc][ic][9].
b: Flattened biases [oc].
out: Flattened output tensor [oc][64].
ic: Number of input channels.
oc: Number of output channels.
Here is the call graph for this function:
Here is the caller graph for this function:

◆ conv3x3_relu_bd16_dispatch()

void conv3x3_relu_bd16_dispatch ( const float *HOT_RESTRICT in,
const float *HOT_RESTRICT w,
const float *HOT_RESTRICT b,
float *HOT_RESTRICT out,
int ic )
inlinestatic
Here is the call graph for this function:
Here is the caller graph for this function:

◆ conv3x3_single_out_relu()

void conv3x3_single_out_relu ( const float *HOT_RESTRICT in,
const float *HOT_RESTRICT w,
float b,
float *HOT_RESTRICT out,
int ic )
inlinestatic

Computes a 3x3 convolution across multiple input channels (ic) to produce a single output channel, followed by ReLU.

Parameters
in: The flattened input tensor [ic][64].
w: The flattened weights [ic][9].
b: The bias scalar.
out: The destination 64-element array for the output channel.
ic: Number of input channels.
Here is the call graph for this function:
Here is the caller graph for this function:

◆ depthwise_conv3x3_relu()

void depthwise_conv3x3_relu ( const float *HOT_RESTRICT in,
const float *HOT_RESTRICT w,
const float *HOT_RESTRICT b,
float *HOT_RESTRICT out,
int channels )
inlinestatic

Computes a depthwise 3x3 convolution followed by a ReLU activation.

In a depthwise convolution, each input channel is convolved with its own set of spatial weights, producing exactly one output channel per input channel, without mixing information across channels.

Parameters
in: Flattened input tensor [channels][64].
w: Flattened depthwise weights [channels][9].
b: Flattened biases [channels].
out: Flattened output tensor [channels][64].
channels: Number of channels (both input and output).
Here is the call graph for this function:
Here is the caller graph for this function:

◆ expert_pool_mode_name()

const char * expert_pool_mode_name ( ExpertPoolMode mode)
inlinestatic

◆ is_slider_branch_rich()

bool is_slider_branch_rich ( int branch_idx)
inlinestatic

Determines if a specific branch corresponds to a slider piece type (Bishop, Rook, Queen).

Parameters
branch_idx: The index of the branch (0-11).
Returns
True if the branch represents a slider piece, false otherwise.
Here is the caller graph for this function:

◆ is_slider_dirty()

bool is_slider_dirty ( const Board & board,
const Board & old_board,
Color side,
PieceType pt )
static

Checks if a slider piece type (Bishop, Rook, Queen) has become "dirty".

A slider is dirty if any piece of that type moved, or if the board occupancy changed in a way that intersects with any of the slider's rays (blocking or unblocking its attack path).

Parameters
board: The new board state.
old_board: The previous board state.
side: The color of the slider pieces to check.
pt: The piece type (BISHOP, ROOK, or QUEEN).
Returns
True if the slider's features need to be recomputed, false otherwise.
Here is the caller graph for this function:

◆ mark_head()

void mark_head ( std::array< uint8_t, 12 > & m,
Color s,
PieceType pt )
static

Marks a specific piece type for a specific color as "dirty" in the dirty mask.

Parameters
m: The 12-element dirty mask (one byte-sized flag per spatial branch).
s: The color of the piece.
pt: The type of the piece.
Here is the caller graph for this function:

◆ pool2x2_region_base()

void pool2x2_region_base ( int region,
int & r0,
int & c0 )
inlinestatic
Here is the caller graph for this function:

◆ pool2x2_region_from_sq()

int pool2x2_region_from_sq ( int sq)
inlinestatic
Here is the caller graph for this function:

◆ simd_add_scaled()

void simd_add_scaled ( float *HOT_RESTRICT dst,
const float *HOT_RESTRICT src,
float scale,
int n )
inlinestatic

SIMD-optimized fused multiply-add (FMA) that adds a scaled vector to a destination vector: dst[i] += src[i] * scale for each of the n elements.

Parameters
dst: The accumulator array.
src: The source array to scale and add.
scale: The scalar multiplier.
n: Number of elements.
Here is the caller graph for this function:

◆ us()

double us ( Clock::time_point a,
Clock::time_point b )
inlinestatic
Here is the caller graph for this function:

Variable Documentation

◆ kDefaultBranchDim

int kDefaultBranchDim = 16
staticconstexpr

◆ kDefaultBypass

int kDefaultBypass = 12
staticconstexpr

◆ kDefaultExpertBottleneck

int kDefaultExpertBottleneck = 32
staticconstexpr

◆ kDefaultExpertHidden

int kDefaultExpertHidden = 128
staticconstexpr

◆ kDefaultExperts

int kDefaultExperts = 4
staticconstexpr

◆ kDefaultGlobals

int kDefaultGlobals = 21
staticconstexpr

◆ kDefaultMixerOut

int kDefaultMixerOut = 64
staticconstexpr

◆ kInputBypassPlanes

int kInputBypassPlanes = 12
staticconstexpr

◆ kInputGlobals

int kInputGlobals = 32
staticconstexpr

◆ kMaxBranchDim

int kMaxBranchDim = NET_BRANCH_DIM
staticconstexpr

◆ kMaxBypass

int kMaxBypass = NET_BYPASS
staticconstexpr

◆ kMaxExpertBottleneck

int kMaxExpertBottleneck = NET_EXPERT_BOTTLENECK
staticconstexpr

◆ kMaxExpertHidden

int kMaxExpertHidden = NET_EXPERT_HIDDEN
staticconstexpr

◆ kMaxExperts

int kMaxExperts = NET_EXPERTS
staticconstexpr

◆ kMaxGlobals

int kMaxGlobals = NET_GLOBALS
staticconstexpr

◆ kMaxMixerOut

int kMaxMixerOut = NET_MIXER_OUT
staticconstexpr

◆ kMixerOcTile

int kMixerOcTile = 8
staticconstexpr

◆ kPool2x2Regions

int kPool2x2Regions = 16
staticconstexpr

◆ NET_BRANCH_DIM

int NET_BRANCH_DIM = 16
staticconstexpr

◆ NET_BYPASS

int NET_BYPASS = 12
staticconstexpr

◆ NET_EXPERT_BOTTLENECK

int NET_EXPERT_BOTTLENECK = 32
staticconstexpr

◆ NET_EXPERT_HIDDEN

int NET_EXPERT_HIDDEN = 128
staticconstexpr

◆ NET_EXPERTS

int NET_EXPERTS = 4
staticconstexpr

◆ NET_GLOBALS

int NET_GLOBALS = 21
staticconstexpr

◆ NET_MIXER_OUT

int NET_MIXER_OUT = 64
staticconstexpr