|
io-chess
UCI chess engine
|
Namespaces | |
| namespace | loss |
Functions | |
| | loss.base_loss_components (pred_wdl, target_wdl, wdl_weight=1.0) |
| | loss.base_loss (pred_wdl, target_wdl, wdl_weight=1.0) |
| | loss.killer_loss (pred_wdl, target_wdl, wdl_weight=1.0) |
| | loss.survivor_loss (pred_wdl, target_wdl, wdl_weight=1.0) |
| | loss.gater_loss (pred_gates, target_gates) |
| | loss.moe_loss (pred_wdl, target_wdl, expert_type='BASE') |
Variables | |
| int | loss.B = 8 |
| | loss.pred_wdl = F.softmax(torch.randn(B, 3), dim=1) |
| | loss.target_wdl = F.softmax(torch.randn(B, 3), dim=1) |
| | loss.loss = moe_loss(pred_wdl, target_wdl, expert_type) |
| | loss.pred_gates = torch.sigmoid(torch.randn(B, 2)) |
| | loss.target_gates = torch.rand(B, 2) |
| | loss.g_loss = gater_loss(pred_gates, target_gates) |