io-chess
UCI chess engine
loss Namespace Reference

Functions

 base_loss_components (pred_wdl, target_wdl, wdl_weight=1.0)
 base_loss (pred_wdl, target_wdl, wdl_weight=1.0)
 killer_loss (pred_wdl, target_wdl, wdl_weight=1.0)
 survivor_loss (pred_wdl, target_wdl, wdl_weight=1.0)
 gater_loss (pred_gates, target_gates)
 moe_loss (pred_wdl, target_wdl, expert_type='BASE')

Variables

int B = 8
 pred_wdl = F.softmax(torch.randn(B, 3), dim=1)
 target_wdl = F.softmax(torch.randn(B, 3), dim=1)
 loss = moe_loss(pred_wdl, target_wdl, expert_type)
 pred_gates = torch.sigmoid(torch.randn(B, 2))
 target_gates = torch.rand(B, 2)
 g_loss = gater_loss(pred_gates, target_gates)

Detailed Description

@file loss.py
@brief Loss functions for training the chess neural network.

Provides functions to compute custom loss metrics combining Win-Draw-Loss (WDL)
probabilities and regularization penalties.

Function Documentation

◆ base_loss()

base_loss ( pred_wdl,
target_wdl,
wdl_weight = 1.0 )

◆ base_loss_components()

base_loss_components ( pred_wdl,
target_wdl,
wdl_weight = 1.0 )
Standard loss for base experts (Tactical, Strategic, Major End, Minor End).

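The primary objective is the WDL term: a Brier-style mean squared error (MSE)
between predicted and target WDL probabilities.

Args:
    pred_wdl: [B, 3] predicted WDL probabilities (from softmax)
    target_wdl: [B, 3] target WDL probabilities
    wdl_weight: weight applied to the WDL term (default 1.0)

Returns:
    (total_loss, weighted_wdl_loss)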
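A framework-free sketch of a Brier-style WDL loss that returns the same `(total_loss, weighted_wdl_loss)` pair. It assumes the squared error is averaged over all B * 3 entries (PyTorch's default `mean` reduction for MSE) and that the WDL term is the only component of the total; `base_loss_components_sketch` and `Rows` are hypothetical names, not part of loss.py.

```python
from typing import List, Tuple

Rows = List[List[float]]  # [B, 3] rows of (win, draw, loss) probabilities


def base_loss_components_sketch(
    pred_wdl: Rows, target_wdl: Rows, wdl_weight: float = 1.0
) -> Tuple[float, float]:
    """Return (total_loss, weighted_wdl_loss) using a Brier-style MSE.

    Assumption: the squared error is averaged over all B * 3 entries,
    matching PyTorch's default 'mean' reduction for MSE.
    """
    if not pred_wdl or len(pred_wdl) != len(target_wdl):
        raise ValueError("pred_wdl and target_wdl must be non-empty and equally long")
    squared_errors = [
        (p - t) ** 2
        for p_row, t_row in zip(pred_wdl, target_wdl)
        for p, t in zip(p_row, t_row)
    ]
    wdl_loss = sum(squared_errors) / len(squared_errors)
    weighted_wdl_loss = wdl_weight * wdl_loss
    # With no other terms in this sketch, the total equals the weighted WDL term.
    total_loss = weighted_wdl_loss
    return total_loss, weighted_wdl_loss
```

On tensors, the equivalent expression under the same assumptions would be `wdl_weight * F.mse_loss(pred_wdl, target_wdl)`.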

◆ gater_loss()

gater_loss ( pred_gates,
target_gates )
Distillation loss for SmartGater.

Args:
    pred_gates: [B, 2] predicted [survivor, killer] gates
    target_gates: [B, 2] teacher [survivor, killer] gates computed in C++
    
Returns:
    Scalar MSE loss
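A minimal sketch of the documented behavior: a scalar MSE between the student's predicted gates and the teacher's gates, with a shape check on the two-column [survivor, killer] layout. `gater_loss_sketch` is a hypothetical name, and rejecting rows that do not have exactly two entries is an assumption added for illustration.

```python
from typing import List

Rows = List[List[float]]  # [B, 2] rows of (survivor, killer) gate values


def gater_loss_sketch(pred_gates: Rows, target_gates: Rows) -> float:
    """Scalar MSE between student gates and teacher gates."""
    if not pred_gates or len(pred_gates) != len(target_gates):
        raise ValueError("pred_gates and target_gates must be non-empty and equally long")
    squared_errors: List[float] = []
    for p_row, t_row in zip(pred_gates, target_gates):
        # Each row is assumed to hold exactly [survivor, killer].
        if len(p_row) != 2 or len(t_row) != 2:
            raise ValueError("each row must be [survivor, killer]")
        squared_errors.extend((p - t) ** 2 for p, t in zip(p_row, t_row))
    return sum(squared_errors) / len(squared_errors)
```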

◆ killer_loss()

killer_loss ( pred_wdl,
target_wdl,
wdl_weight = 1.0 )
Zoomed loss for Killer expert (winning positions).

KILLER is trained on positions where the side to move (STM) is winning (eval > 150 cp).
Because its output is combined with the base experts' output via residual addition,
the killer expert learns to predict a CORRECTION that emphasizes win certainty.

Zoomed WDL: Maps WinProb [0.60, 1.0] -> [0.0, 1.0]
This creates stronger gradients in winning positions where base experts plateau.

The output should be interpreted as:
- High zoomed_win = "base expert underestimates win probability, add more"
- Low zoomed_win = "base expert is already confident enough"
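A sketch of the zoomed WDL idea, assuming the mapping of WinProb [0.60, 1.0] onto [0.0, 1.0] is a linear rescale clamped at both ends, that column 0 of each row is the win probability, and that the loss is an MSE on the zoomed values. `zoom_win`, `killer_loss_sketch`, and `ZOOM_LOW` are hypothetical names; the real `killer_loss` may combine additional terms.

```python
from typing import List

Rows = List[List[float]]  # [B, 3] rows; assumption: column 0 is the win probability
ZOOM_LOW = 0.60           # lower edge of the zoom window, from the docstring


def zoom_win(win_prob: float, low: float = ZOOM_LOW) -> float:
    """Linearly map [low, 1.0] onto [0.0, 1.0], clamping values outside the window."""
    zoomed = (win_prob - low) / (1.0 - low)
    return min(1.0, max(0.0, zoomed))


def killer_loss_sketch(pred_wdl: Rows, target_wdl: Rows, wdl_weight: float = 1.0) -> float:
    """MSE between zoomed predicted and zoomed target win probabilities."""
    if not pred_wdl or len(pred_wdl) != len(target_wdl):
        raise ValueError("pred_wdl and target_wdl must be non-empty and equally long")
    errors = [(zoom_win(p[0]) - zoom_win(t[0])) ** 2 for p, t in zip(pred_wdl, target_wdl)]
    return wdl_weight * sum(errors) / len(errors)
```

Stretching the window by a factor of 1 / 0.4 = 2.5 makes an error inside it 2.5 times larger after zooming, so its squared error is 6.25 times larger; this is one way to read "stronger gradients". In this sketch the clamp also gives zero gradient to predictions below 0.60, and whether the real function clamps is not stated.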

◆ moe_loss()

moe_loss ( pred_wdl,
target_wdl,
expert_type = 'BASE' )
Unified loss function with expert type selection.

Args:
    pred_wdl: [B, 3]
    target_wdl: [B, 3]
    expert_type: 'BASE', 'KILLER', or 'SURVIVOR'
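One plausible way to implement the `expert_type` selection is a dictionary dispatch that rejects unknown names. Everything below is hypothetical: `moe_loss_sketch`, `LOSS_BY_EXPERT`, and the per-expert losses are placeholders standing in for `base_loss`, `killer_loss`, and `survivor_loss`, and raising `ValueError` on an unknown type is an assumption about how invalid input is handled.

```python
from typing import Callable, Dict, List

Rows = List[List[float]]  # [B, 3] rows of (win, draw, loss) probabilities
LossFn = Callable[[Rows, Rows], float]


def mse(pred: Rows, target: Rows) -> float:
    """Mean squared error over every entry of two equally shaped batches."""
    errors = [(p - t) ** 2 for pr, tr in zip(pred, target) for p, t in zip(pr, tr)]
    return sum(errors) / len(errors)


def column(rows: Rows, idx: int) -> Rows:
    """Keep a single column, as [B, 1] rows."""
    return [[row[idx]] for row in rows]


# Placeholder losses standing in for base_loss, killer_loss and survivor_loss.
LOSS_BY_EXPERT: Dict[str, LossFn] = {
    "BASE": mse,
    "KILLER": lambda p, t: mse(column(p, 0), column(t, 0)),    # win column only
    "SURVIVOR": lambda p, t: mse(column(p, 2), column(t, 2)),  # loss column only
}


def moe_loss_sketch(pred_wdl: Rows, target_wdl: Rows, expert_type: str = "BASE") -> float:
    """Route the batch to the loss registered for the requested expert type."""
    loss_fn = LOSS_BY_EXPERT.get(expert_type)
    if loss_fn is None:
        raise ValueError(
            f"unknown expert_type {expert_type!r}; expected one of {sorted(LOSS_BY_EXPERT)}"
        )
    return loss_fn(pred_wdl, target_wdl)
```

A dispatch table keeps the selection logic in one place: supporting another expert type means registering one more entry rather than extending an if/elif chain.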

◆ survivor_loss()

survivor_loss ( pred_wdl,
target_wdl,
wdl_weight = 1.0 )
Zoomed loss for Survivor expert (losing positions).

SURVIVOR is trained on positions where the side to move (STM) is losing (eval < -150 cp).
The goal is to find defensive resources: draws or ways to prolong the game.

Zoomed WDL: Maps LossProb [0.60, 1.0] -> [0.0, 1.0]
Focus: Maximize draw probability, not just predict loss correctly.

The output emphasizes:
- Draw probability as the primary "hope" signal
- Lower zoomed_loss = more swindle potential
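A sketch combining the two ideas in the docstring, under several assumptions: the LossProb zoom is a clamped linear rescale of [0.60, 1.0] onto [0.0, 1.0], columns 1 and 2 hold the draw and loss probabilities, and the draw emphasis is modeled as an extra MSE term scaled by a hypothetical `draw_weight`. `zoom_loss_prob`, `survivor_loss_sketch`, and `draw_weight` are illustrative names, not the module's API.

```python
from typing import List

Rows = List[List[float]]  # [B, 3] rows; assumption: column 1 = draw, column 2 = loss
ZOOM_LOW = 0.60           # lower edge of the zoom window, from the docstring


def zoom_loss_prob(loss_prob: float, low: float = ZOOM_LOW) -> float:
    """Linearly map [low, 1.0] onto [0.0, 1.0], clamping values outside the window."""
    zoomed = (loss_prob - low) / (1.0 - low)
    return min(1.0, max(0.0, zoomed))


def survivor_loss_sketch(
    pred_wdl: Rows, target_wdl: Rows, wdl_weight: float = 1.0, draw_weight: float = 1.0
) -> float:
    """Zoomed loss-probability MSE plus a separately weighted draw-probability MSE."""
    n = len(pred_wdl)
    if n == 0 or n != len(target_wdl):
        raise ValueError("pred_wdl and target_wdl must be non-empty and equally long")
    zoom_sq = 0.0
    draw_sq = 0.0
    for p, t in zip(pred_wdl, target_wdl):
        zoom_sq += (zoom_loss_prob(p[2]) - zoom_loss_prob(t[2])) ** 2
        draw_sq += (p[1] - t[1]) ** 2  # hypothetical "hope" term on the draw column
    return wdl_weight * (zoom_sq / n + draw_weight * draw_sq / n)
```

Raising `draw_weight` in this sketch makes errors on the draw column dominate, which is one way to express "draw probability as the primary hope signal".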

Variable Documentation

◆ B

int loss.B = 8

◆ g_loss

loss.g_loss = gater_loss(pred_gates, target_gates)

◆ loss

loss.loss = moe_loss(pred_wdl, target_wdl, expert_type)

◆ pred_gates

loss.pred_gates = torch.sigmoid(torch.randn(B, 2))

◆ pred_wdl

loss.pred_wdl = F.softmax(torch.randn(B, 3), dim=1)

◆ target_gates

loss.target_gates = torch.rand(B, 2)

◆ target_wdl

loss.target_wdl = F.softmax(torch.randn(B, 3), dim=1)