io-chess
UCI chess engine
Loading...
Searching...
No Matches
ExpertRouter.hpp
Go to the documentation of this file.
1
5#pragma once
8#include <algorithm>
9#include <array>
10#include <cmath>
11#include <cstdint>
12
31public:
// Expert layout: NUM_BASE softmax-normalized base experts followed by NUM_AUX
// independently gated auxiliary experts (NUM_BASE + NUM_AUX == NUM_EXPERTS).
32 static constexpr int NUM_EXPERTS = 6;
33 static constexpr int NUM_BASE = 4; // Base experts (softmax group)
34 static constexpr int NUM_AUX = 2; // Auxiliary experts (independent gates)
35
36 // Expert indices
// Base experts: TACTICAL, STRATEGIC, MAJOR_END, MINOR_END; aux experts:
// SURVIVOR, KILLER. softmax()'s fallback path relies on TACTICAL == 0 and
// STRATEGIC == 1, and get_top_k / compute_weights index ExpertWeights arrays
// with these values directly.
37 enum Expert : int {
38 // Base Group (indices 0-3, softmax)
43 // Aux Group (indices 4-5, independent)
46 };
47
48 // Output structure for routing (Residual MoE)
49 // Binary format: 6 floats = [base0, base1, base2, base3, survivor, killer]
50 // - base[0-3]: sum to 1.0 (softmax normalized)
51 // - survivor/killer: independent 0.0 to 1.0 (gate activations)
// raw_scores: base entries hold the pre-softmax scores after the SCORE_MIN
// floor; aux entries repeat the gate values (see compute_weights).
// Two arrays of 6 floats = 48 bytes; alignas(32) pads sizeof to 64.
52 struct alignas(32) ExpertWeights {
53 float weights[NUM_EXPERTS]; // Final weights for output
54 float raw_scores[NUM_EXPERTS]; // Pre-normalization scores (debugging)
55 };
56
57 // =========================================================================
58 // Thresholds (tunable)
59 // =========================================================================
// Phase value at which compute_weights' endgame_factor sigmoid equals 0.5.
60 static constexpr float PHASE_ENDGAME_THRESHOLD = 0.65f;
61
62 // CCT thresholds
// CCT is the score returned by compute_cct (checks / captures / threats).
// Presumably these bound the CCT range mapped onto tactical_strength in
// compute_weights — TODO confirm against the full expression.
63 static constexpr float CCT_TACTICAL_THRESHOLD = 14.0f;
64 static constexpr float CCT_STRATEGIC_THRESHOLD = 6.0f;
65
66 // Aux gate thresholds (in CENTIPAWNS)
// Each gate ramps linearly from 0 at its threshold to 1 at
// threshold + AUX_RAMP_RANGE (i.e. |eval_cp| >= 650 saturates the gate).
67 static constexpr float KILLER_CP_THRESHOLD =
68 150.0f; // Activate when winning 1.5+ pawns
69 static constexpr float SURVIVOR_CP_THRESHOLD =
70 150.0f; // Activate when losing 1.5+ pawns
71 static constexpr float AUX_RAMP_RANGE = 500.0f; // CP range for 0->1 ramp
72
73 // CCT weights
// Added once per side by compute_cct (side in check / side with an
// attacked rook or queen).
74 static constexpr float CCT_WEIGHT_CHECK = 10.0f;
75 static constexpr float CCT_WEIGHT_MAJOR_THREAT = 3.0f;
76
77 // Piece values for capture scoring
78 static constexpr float PIECE_VALUE_PAWN = 1.0f;
79 static constexpr float PIECE_VALUE_KNIGHT = 3.0f;
80 static constexpr float PIECE_VALUE_BISHOP = 3.0f;
81 static constexpr float PIECE_VALUE_ROOK = 5.0f;
82 static constexpr float PIECE_VALUE_QUEEN = 9.0f;
83
84 // Base expert score magnitudes (continuous scoring)
85 // SCORE_MAX: maximum score for dominant expert
86 // SCORE_MIN: minimum score to prevent softmax degeneration
// NOTE(review): softmax already gives every score a positive weight
// (exp(0) == 1), so the SCORE_MIN floor does not prevent zero weights; it
// limits the spread between leader and others to SCORE_MAX - SCORE_MIN.
87 static constexpr float SCORE_MAX = 5.0f;
88 static constexpr float SCORE_MIN = 0.5f;
89
90 // =========================================================================
91 // Main API
92 // =========================================================================
93
// Computes the 6 expert routing weights for a ChessInput position.
//  - weights[0..NUM_BASE): softmax over continuous base-expert scores
//    (middlegame_factor splits TACTICAL/STRATEGIC, endgame_factor splits
//    MAJOR_END/MINOR_END); they sum to 1.
//  - weights[SURVIVOR], weights[KILLER]: independent linear gates in [0, 1]
//    that depend only on eval_cp.
//  - raw_scores: floored pre-softmax base scores, then the two gate values.
// @param input   Position features (phase, CCT and material inputs).
// @param eval_cp Evaluation in centipawns; negative means "losing" per the
//                gate comments below. Presumably side-to-move perspective —
//                TODO confirm with callers.
// @param out     Receives weights and raw_scores.
106 static void compute_weights(const ChessInput &input, float eval_cp,
107 ExpertWeights &out) {
108 // Compute derived features
109 float phase = compute_phase(input); // 0.0 = opening, 1.0 = endgame
110 float cct = compute_cct(input); // Higher = more tactical
111 float major_ratio =
112 compute_major_ratio(input); // 0.0 = no majors, 1.0 = all majors
113
114 // ========== GROUP A: Base Experts (Continuous Scores) ==========
115 // Each expert gets a continuous score based on position features.
116 // Softmax will produce natural weights that reflect the position's
117 // character.
118
119 float raw_base[NUM_BASE] = {0.0f};
120
121 // --- Phase-based split: Middlegame vs Endgame ---
122 // Smooth transition around PHASE_ENDGAME_THRESHOLD (0.65)
123 // sigmoid centered at threshold, steepness controls blend zone
124 float endgame_factor = sigmoid((phase - PHASE_ENDGAME_THRESHOLD) * 8.0f);
125 float middlegame_factor = 1.0f - endgame_factor;
126
127 // --- Middlegame: Tactical vs Strategic (CCT-based) ---
128 // Smooth interpolation across entire CCT range
// tactical_strength is used as a [0, 1] split of middlegame_factor * SCORE_MAX
// between TACTICAL and STRATEGIC — assumes the full expression keeps it in
// [0, 1]; TODO confirm.
129 float tactical_strength =
132 2.0f);
133
134 raw_base[TACTICAL] = middlegame_factor * tactical_strength * SCORE_MAX;
135 raw_base[STRATEGIC] =
136 middlegame_factor * (1.0f - tactical_strength) * SCORE_MAX;
137
138 // --- Endgame: Major vs Minor (piece type based) ---
139 // major_ratio: 0.0 = pure minor endgame, 1.0 = pure major endgame
140 // Blend smoothly based on actual piece composition
141 raw_base[MAJOR_END] = endgame_factor * major_ratio * SCORE_MAX;
142 raw_base[MINOR_END] = endgame_factor * (1.0f - major_ratio) * SCORE_MAX;
143
144 // --- Ensure minimum score for active experts ---
145 // This prevents any expert from getting exactly 0 (softmax issues)
// NOTE(review): softmax never outputs 0 for a finite score; the floor
// effectively caps the score gap at SCORE_MAX - SCORE_MIN.
146 for (int i = 0; i < NUM_BASE; ++i) {
147 raw_base[i] = std::max(raw_base[i], SCORE_MIN);
148 }
149
150 // Softmax normalize base experts (sum to 1.0)
151 softmax(raw_base, out.weights, NUM_BASE);
152
153 // Copy raw scores for debugging / dataset selection
154 for (int i = 0; i < NUM_BASE; ++i) {
155 out.raw_scores[i] = raw_base[i];
156 }
157
158 // ========== GROUP B: Aux Experts (Independent Gates) ==========
159
160 // SURVIVOR (Index 4): Linear ramp when losing
// 0 for eval_cp >= -150, rising linearly to 1 at eval_cp <= -650.
161 if (eval_cp < -SURVIVOR_CP_THRESHOLD) {
162 float deficit = std::abs(eval_cp) - SURVIVOR_CP_THRESHOLD;
163 out.weights[SURVIVOR] = std::clamp(deficit / AUX_RAMP_RANGE, 0.0f, 1.0f);
164 } else {
165 out.weights[SURVIVOR] = 0.0f;
166 }
167 out.raw_scores[SURVIVOR] = out.weights[SURVIVOR]; // Gate value = raw score
168
169 // KILLER (Index 5): Linear ramp when winning
// 0 for eval_cp <= 150, rising linearly to 1 at eval_cp >= 650.
170 if (eval_cp > KILLER_CP_THRESHOLD) {
171 float bonus = eval_cp - KILLER_CP_THRESHOLD;
172 out.weights[KILLER] = std::clamp(bonus / AUX_RAMP_RANGE, 0.0f, 1.0f);
173 } else {
174 out.weights[KILLER] = 0.0f;
175 }
176 out.raw_scores[KILLER] = out.weights[KILLER]; // Gate value = raw score
177 }
178
// FactorizedInput overload of compute_weights: same output contract
// (softmax-normalized base weights, independent SURVIVOR/KILLER gates driven by
// eval_cp). The visible logic mirrors the ChessInput overload step for step;
// consider sharing a helper that takes (phase, cct, major_ratio, eval_cp) so
// the two cannot drift apart.
// @param input   Factorized position features.
// @param eval_cp Evaluation in centipawns (negative = losing for the SURVIVOR
//                gate).
// @param out     Receives weights and raw_scores.
183 static void compute_weights(const FactorizedInput &input, float eval_cp,
184 ExpertWeights &out) {
185 float phase = compute_phase(input);
186 float cct = compute_cct(input);
187 float major_ratio = compute_major_ratio(input);
188
189 float raw_base[NUM_BASE] = {0.0f};
190
191 float endgame_factor = sigmoid((phase - PHASE_ENDGAME_THRESHOLD) * 8.0f);
192 float middlegame_factor = 1.0f - endgame_factor;
193
194 float tactical_strength =
197 4.0f -
198 2.0f);
199
200 raw_base[TACTICAL] = middlegame_factor * tactical_strength * SCORE_MAX;
201 raw_base[STRATEGIC] =
202 middlegame_factor * (1.0f - tactical_strength) * SCORE_MAX;
203 raw_base[MAJOR_END] = endgame_factor * major_ratio * SCORE_MAX;
204 raw_base[MINOR_END] = endgame_factor * (1.0f - major_ratio) * SCORE_MAX;
205
// Floor every base score at SCORE_MIN before the softmax.
206 for (int i = 0; i < NUM_BASE; ++i) {
207 raw_base[i] = std::max(raw_base[i], SCORE_MIN);
208 }
209
210 softmax(raw_base, out.weights, NUM_BASE);
211
212 for (int i = 0; i < NUM_BASE; ++i) {
213 out.raw_scores[i] = raw_base[i];
214 }
215
// SURVIVOR gate: linear ramp from 0 at -150 cp to 1 at -650 cp.
216 if (eval_cp < -SURVIVOR_CP_THRESHOLD) {
217 float deficit = std::abs(eval_cp) - SURVIVOR_CP_THRESHOLD;
218 out.weights[SURVIVOR] = std::clamp(deficit / AUX_RAMP_RANGE, 0.0f, 1.0f);
219 } else {
220 out.weights[SURVIVOR] = 0.0f;
221 }
// NOTE(review): the ChessInput overload sets out.raw_scores[SURVIVOR] at this
// point; confirm this overload does too, otherwise that slot is left stale.
223
// KILLER gate: linear ramp from 0 at +150 cp to 1 at +650 cp.
224 if (eval_cp > KILLER_CP_THRESHOLD) {
225 float bonus = eval_cp - KILLER_CP_THRESHOLD;
226 out.weights[KILLER] = std::clamp(bonus / AUX_RAMP_RANGE, 0.0f, 1.0f);
227 } else {
228 out.weights[KILLER] = 0.0f;
229 }
230 out.raw_scores[KILLER] = out.weights[KILLER];
231 }
232
243 static void get_top_k(const ExpertWeights &weights, int k, int *indices,
244 float *probs) {
245 // Create index array
246 int idx[NUM_EXPERTS] = {0, 1, 2, 3, 4, 5};
247
248 // Partial sort to get top-k
249 std::partial_sort(idx, idx + k, idx + NUM_EXPERTS,
250 [&weights](int a, int b) {
251 return weights.weights[a] > weights.weights[b];
252 });
253
254 // Copy top-k indices
255 for (int i = 0; i < k; ++i) {
256 indices[i] = idx[i];
257 }
258
259 // Renormalize probabilities for top-k
260 float sum = 0.0f;
261 for (int i = 0; i < k; ++i) {
262 sum += weights.weights[idx[i]];
263 }
264 for (int i = 0; i < k; ++i) {
265 probs[i] = weights.weights[idx[i]] / sum;
266 }
267 }
268
269private:
270 // =========================================================================
271 // Feature computation
272 // =========================================================================
273
// Game phase used by compute_weights, which documents the value as
// 0.0 = opening, 1.0 = endgame and feeds it to a sigmoid centred on
// PHASE_ENDGAME_THRESHOLD. Presumably returns the PHASE global feature
// (GlobalIndices::PHASE) — TODO confirm against FeatureExtractor.
279 static float compute_phase(const ChessInput &input) {
281 }
282
// FactorizedInput counterpart of compute_phase; presumably reads the
// factorized extractor's PHASE global feature — TODO confirm.
283 static float compute_phase(const FactorizedInput &input) {
285 }
286
// Tactical-complexity ("CCT": checks / captures / threats) score for a
// ChessInput; higher means more tactical. Contributions:
//   +CCT_WEIGHT_CHECK for each of the two check layers that is non-empty,
//   +piece value per enemy/own piece type with at least one piece on a threat
//    square (presence per type, not a count of pieces),
//   +CCT_WEIGHT_MAJOR_THREAT per side with a threatened rook or queen.
// Presumably the piece_value arguments are the PIECE_VALUE_* constants, which
// would put the maximum at 20 + 2 * 21 + 6 = 68 — TODO confirm.
291 static float compute_cct(const ChessInput &input) {
293
294 float cct = 0.0f;
295
296 // ========== CHECKS (Weight: 10) ==========
// Both flags add the same weight, so which layer maps to "us" vs "them" does
// not change the score.
// NOTE(review): count_layer scans all 64 bytes just to test non-emptiness.
297 bool us_in_check = count_layer(input.layers[L::THEM_CHECKS]) > 0;
298 bool them_in_check = count_layer(input.layers[L::US_CHECKS]) > 0;
299 cct += (us_in_check ? CCT_WEIGHT_CHECK : 0.0f);
300 cct += (them_in_check ? CCT_WEIGHT_CHECK : 0.0f);
301
302 // ========== CAPTURES (Weighted by piece value) ==========
303 // US_THREATS shows where we can capture their pieces
304
305 // Count threatened pieces by type
306 cct += piece_threat_value(input, L::US_THREATS, L::THEM_PAWN,
308 cct += piece_threat_value(input, L::US_THREATS, L::THEM_KNIGHT,
310 cct += piece_threat_value(input, L::US_THREATS, L::THEM_BISHOP,
312 cct += piece_threat_value(input, L::US_THREATS, L::THEM_ROOK,
314 cct += piece_threat_value(input, L::US_THREATS, L::THEM_QUEEN,
316
317 // THEM_THREATS shows where they can capture our pieces
318 cct += piece_threat_value(input, L::THEM_THREATS, L::US_PAWN,
320 cct += piece_threat_value(input, L::THEM_THREATS, L::US_KNIGHT,
322 cct += piece_threat_value(input, L::THEM_THREATS, L::US_BISHOP,
324 cct += piece_threat_value(input, L::THEM_THREATS, L::US_ROOK,
326 cct += piece_threat_value(input, L::THEM_THREATS, L::US_QUEEN,
328
329 // ========== MAJOR THREATS BONUS (Weight: 3) ==========
// NOTE(review): these four has_overlap calls repeat the rook/queen tests
// already performed by piece_threat_value above; caching them would avoid
// the duplicate scans.
330 bool us_major_threatened =
331 has_overlap(input.layers[L::THEM_THREATS], input.layers[L::US_ROOK]) ||
332 has_overlap(input.layers[L::THEM_THREATS], input.layers[L::US_QUEEN]);
333 bool them_major_threatened =
334 has_overlap(input.layers[L::US_THREATS], input.layers[L::THEM_ROOK]) ||
335 has_overlap(input.layers[L::US_THREATS], input.layers[L::THEM_QUEEN]);
336
337 cct += (us_major_threatened ? CCT_WEIGHT_MAJOR_THREAT : 0.0f);
338 cct += (them_major_threatened ? CCT_WEIGHT_MAJOR_THREAT : 0.0f);
339
340 return cct;
341 }
342
// Returns true if any attacker group g in the inclusive range
// [attacker_group_begin, attacker_group_end] has an ATTACKS value > 0.5 on a
// square where target_presence_group's PRESENCE value is > 0.5.
// Worst case: (end - begin + 1) * 64 iterations; returns at the first hit.
// NOTE(review): the PRESENCE test does not depend on g; iterating squares in
// the outer loop and skipping empty target squares first would avoid
// re-reading them for every attacker group.
343 static bool has_overlap_factorized(const FactorizedInput &input,
344 int attacker_group_begin,
345 int attacker_group_end,
346 int target_presence_group) {
348 for (int g = attacker_group_begin; g <= attacker_group_end; ++g) {
349 for (int sq = 0; sq < 64; ++sq) {
// Planes hold floats; > 0.5f treats them as booleans.
350 if (input.branches[g][FE::ATTACKS][sq] > 0.5f &&
351 input.branches[target_presence_group][FE::PRESENCE][sq] > 0.5f) {
352 return true;
353 }
354 }
355 }
356 return false;
357 }
358
// Tactical-complexity ("CCT") score for a FactorizedInput, built from the
// per-piece-group ATTACKS / PRESENCE planes via has_overlap_factorized:
//   +CCT_WEIGHT_CHECK per side whose king's square is attacked by any enemy
//    group (derived from attacks, not from dedicated check planes),
//   +PIECE_VALUE_* per piece type of a side with at least one attacked piece,
//   +CCT_WEIGHT_MAJOR_THREAT per side with an attacked rook or queen.
// Maximum score: 2 * 10 + 2 * (1 + 3 + 3 + 5 + 9) + 2 * 3 = 68.
// Assumes the US_* and THEM_* group indices form contiguous PAWN..KING ranges,
// because they are scanned as inclusive [begin, end] — TODO confirm.
// NOTE(review): the ROOK/QUEEN overlap tests run twice (capture scoring and
// major-threat bonus); caching those four results would save up to 4 of the
// 16 has_overlap_factorized calls.
359 static float compute_cct(const FactorizedInput &input) {
361
362 float cct = 0.0f;
363
364 const int us_begin = FE::US_PAWN;
365 const int us_end = FE::US_KING;
366 const int them_begin = FE::THEM_PAWN;
367 const int them_end = FE::THEM_KING;
368
// Check detection: any enemy group attacking the king's square.
369 bool us_in_check = has_overlap_factorized(input, them_begin, them_end,
370 FE::US_KING);
371 bool them_in_check = has_overlap_factorized(input, us_begin, us_end,
372 FE::THEM_KING);
373 cct += (us_in_check ? CCT_WEIGHT_CHECK : 0.0f);
374 cct += (them_in_check ? CCT_WEIGHT_CHECK : 0.0f);
375
// Our attacks on their pieces, one value per attacked piece type.
376 if (has_overlap_factorized(input, us_begin, us_end, FE::THEM_PAWN))
377 cct += PIECE_VALUE_PAWN;
378 if (has_overlap_factorized(input, us_begin, us_end,
379 FE::THEM_KNIGHT))
380 cct += PIECE_VALUE_KNIGHT;
381 if (has_overlap_factorized(input, us_begin, us_end,
382 FE::THEM_BISHOP))
383 cct += PIECE_VALUE_BISHOP;
384 if (has_overlap_factorized(input, us_begin, us_end, FE::THEM_ROOK))
385 cct += PIECE_VALUE_ROOK;
386 if (has_overlap_factorized(input, us_begin, us_end,
387 FE::THEM_QUEEN))
388 cct += PIECE_VALUE_QUEEN;
389
// Their attacks on our pieces.
390 if (has_overlap_factorized(input, them_begin, them_end, FE::US_PAWN))
391 cct += PIECE_VALUE_PAWN;
392 if (has_overlap_factorized(input, them_begin, them_end,
393 FE::US_KNIGHT))
394 cct += PIECE_VALUE_KNIGHT;
395 if (has_overlap_factorized(input, them_begin, them_end,
396 FE::US_BISHOP))
397 cct += PIECE_VALUE_BISHOP;
398 if (has_overlap_factorized(input, them_begin, them_end, FE::US_ROOK))
399 cct += PIECE_VALUE_ROOK;
400 if (has_overlap_factorized(input, them_begin, them_end,
401 FE::US_QUEEN))
402 cct += PIECE_VALUE_QUEEN;
403
// Major-threat bonus (same overlap tests as the rook/queen lines above).
404 bool us_major_threatened =
405 has_overlap_factorized(input, them_begin, them_end, FE::US_ROOK) ||
406 has_overlap_factorized(input, them_begin, them_end,
407 FE::US_QUEEN);
408 bool them_major_threatened =
409 has_overlap_factorized(input, us_begin, us_end, FE::THEM_ROOK) ||
410 has_overlap_factorized(input, us_begin, us_end,
411 FE::THEM_QUEEN);
412
413 cct += (us_major_threatened ? CCT_WEIGHT_MAJOR_THREAT : 0.0f);
414 cct += (them_major_threatened ? CCT_WEIGHT_MAJOR_THREAT : 0.0f);
415
416 return cct;
417 }
418
// True if either side has any rook or queen material (global feature > 0).
// NOTE(review): not called from compute_weights or compute_cct; confirm it is
// still used elsewhere before keeping it.
422 static bool has_major_pieces(const ChessInput &input) {
424 return input.global[G::US_MAT_ROOK] > 0.0f ||
425 input.global[G::THEM_MAT_ROOK] > 0.0f ||
426 input.global[G::US_MAT_QUEEN] > 0.0f ||
427 input.global[G::THEM_MAT_QUEEN] > 0.0f;
428 }
429
// Fraction of non-pawn material (both sides combined) held in rooks and
// queens: major / (major + minor), in [0, 1]. Returns 0.5 when there is no
// minor or major material at all (e.g. pure pawn endings).
// The *2 (R/N/B) and *1 (Q) multipliers suggest the MAT_* globals are counts
// normalized by their starting numbers — assumption, TODO confirm against
// FeatureExtractor.
439 static float compute_major_ratio(const ChessInput &input) {
441
442 // Calculate major piece material (rooks + queens)
443 float major_mat = 0.0f;
444 major_mat +=
445 (input.global[G::US_MAT_ROOK] + input.global[G::THEM_MAT_ROOK]) * 2.0f *
447 major_mat +=
448 (input.global[G::US_MAT_QUEEN] + input.global[G::THEM_MAT_QUEEN]) *
449 1.0f * PIECE_VALUE_QUEEN;
450
451 // Calculate minor piece material (knights + bishops)
452 float minor_mat = 0.0f;
453 minor_mat +=
454 (input.global[G::US_MAT_KNIGHT] + input.global[G::THEM_MAT_KNIGHT]) *
455 2.0f * PIECE_VALUE_KNIGHT;
456 minor_mat +=
457 (input.global[G::US_MAT_BISHOP] + input.global[G::THEM_MAT_BISHOP]) *
458 2.0f * PIECE_VALUE_BISHOP;
459
// Guard against dividing by (near) zero.
460 float total = major_mat + minor_mat;
461 if (total < 1e-6f) {
462 return 0.5f; // Pawn endgame - neutral
463 }
464
465 return major_mat / total;
466 }
467
// FactorizedInput overload of compute_major_ratio: rook+queen share of the
// combined non-pawn material, in [0, 1], or 0.5 when that material is ~0.
// The visible body duplicates the ChessInput overload, differing only in the
// input type; a shared template or helper would keep them in sync.
468 static float compute_major_ratio(const FactorizedInput &input) {
470
471 float major_mat = 0.0f;
472 major_mat +=
473 (input.global[G::US_MAT_ROOK] + input.global[G::THEM_MAT_ROOK]) * 2.0f *
475 major_mat +=
476 (input.global[G::US_MAT_QUEEN] + input.global[G::THEM_MAT_QUEEN]) *
477 1.0f * PIECE_VALUE_QUEEN;
478
479 float minor_mat = 0.0f;
480 minor_mat +=
481 (input.global[G::US_MAT_KNIGHT] + input.global[G::THEM_MAT_KNIGHT]) *
482 2.0f * PIECE_VALUE_KNIGHT;
483 minor_mat +=
484 (input.global[G::US_MAT_BISHOP] + input.global[G::THEM_MAT_BISHOP]) *
485 2.0f * PIECE_VALUE_BISHOP;
486
// Neutral 0.5 when there is no non-pawn material to compare.
487 float total = major_mat + minor_mat;
488 if (total < 1e-6f) {
489 return 0.5f;
490 }
491
492 return major_mat / total;
493 }
494
// Absolute material difference between the two sides in pawn units, using
// weights 1/3/3/5/9 and multipliers 8/2/2/2/1 (presumably undoing a per-type
// normalization of the MAT_* globals — TODO confirm).
// NOTE(review): hard-codes the piece values instead of using the
// PIECE_VALUE_* constants, so the two can drift; also not called from
// compute_weights or compute_cct.
498 static float compute_material_imbalance(const ChessInput &input) {
500
501 float us_mat = 0.0f;
502 us_mat += input.global[G::US_MAT_PAWN] * 8.0f * 1.0f;
503 us_mat += input.global[G::US_MAT_KNIGHT] * 2.0f * 3.0f;
504 us_mat += input.global[G::US_MAT_BISHOP] * 2.0f * 3.0f;
505 us_mat += input.global[G::US_MAT_ROOK] * 2.0f * 5.0f;
506 us_mat += input.global[G::US_MAT_QUEEN] * 1.0f * 9.0f;
507
508 float them_mat = 0.0f;
509 them_mat += input.global[G::THEM_MAT_PAWN] * 8.0f * 1.0f;
510 them_mat += input.global[G::THEM_MAT_KNIGHT] * 2.0f * 3.0f;
511 them_mat += input.global[G::THEM_MAT_BISHOP] * 2.0f * 3.0f;
512 them_mat += input.global[G::THEM_MAT_ROOK] * 2.0f * 5.0f;
513 them_mat += input.global[G::THEM_MAT_QUEEN] * 1.0f * 9.0f;
514
// Sign is discarded: result says how unbalanced, not who is ahead.
515 return std::abs(us_mat - them_mat);
516 }
517
// =========================================================================
// Utility functions
// =========================================================================

/// Counts the squares of a 64-byte board plane whose byte is non-zero.
///
/// The plane is read byte by byte. The previous implementation reinterpreted
/// the uint8_t array as uint64_t words, which violates the strict-aliasing
/// rule and may perform misaligned loads (undefined behaviour). A plain byte
/// scan is well-defined, and compilers vectorise it.
///
/// @param layer 64 bytes, one per square; any non-zero byte counts as set.
/// @return Number of set squares, in [0, 64].
static int count_layer(const uint8_t layer[64]) {
    return static_cast<int>(std::count_if(
        layer, layer + 64, [](uint8_t square) { return square != 0; }));
}
543
/// Returns true when two 64-byte board planes share a set bit on the same
/// square, i.e. (a[sq] & b[sq]) != 0 for some sq in [0, 64).
///
/// This is byte-for-byte equivalent to AND-ing the planes as eight 64-bit
/// words. The former uint64_t reinterpret_cast broke strict aliasing and could
/// load misaligned words (undefined behaviour). The branch-free
/// OR-accumulation keeps the loop trivially vectorisable.
///
/// @param a First plane (64 bytes).
/// @param b Second plane (64 bytes).
static bool has_overlap(const uint8_t a[64], const uint8_t b[64]) {
    uint8_t common = 0;
    for (int sq = 0; sq < 64; ++sq) {
        common |= static_cast<uint8_t>(a[sq] & b[sq]);
    }
    return common != 0;
}
572
576 static float piece_threat_value(const ChessInput &input, int threat_layer,
577 int piece_layer, float piece_value) {
578 if (has_overlap(input.layers[threat_layer], input.layers[piece_layer])) {
579 return piece_value;
580 }
581 return 0.0f;
582 }
583
// =========================================================================
// Math Helpers
// =========================================================================

/// Cheap exp(x) approximation based on the compound-interest limit
/// exp(x) ~= (1 + x/256)^256, evaluated with eight successive squarings
/// (2^8 = 256). This is not Schraudolph's bit-manipulation trick.
///
/// Accuracy falls off as |x| grows (about 1% low at x = 1, ~17% low at
/// x = 10), which is acceptable for routing weights.
/// Inputs below -88 return exactly 0; inputs above 88 are clamped to 88.
/// For non-NaN input the result is >= 0, since the base
/// 1 + x/256 >= 1 - 88/256 > 0. NaN propagates.
static inline float fast_exp(float x) {
    if (x < -88.0f) {
        return 0.0f; // deep underflow region: treat as zero
    }
    const float clamped = (x > 88.0f) ? 88.0f : x;

    float result = 1.0f + clamped / 256.0f;
    for (int squaring = 0; squaring < 8; ++squaring) {
        result *= result;
    }
    return result;
}
612
618 static float sigmoid(float x) {
619 if (x > 10.0f)
620 return 1.0f;
621 if (x < -10.0f)
622 return 0.0f;
623 return 1.0f / (1.0f + fast_exp(-x));
624 }
625
630 static void softmax(const float *scores, float *probs, int n) {
631 // 1. Find Max (Shift invariance to prevent overflow)
632 float max_score = scores[0];
633 for (int i = 1; i < n; ++i) {
634 if (scores[i] > max_score)
635 max_score = scores[i];
636 }
637
638 // 2. Exponentiate & Sum
639 float sum = 0.0f;
640 for (int i = 0; i < n; ++i) {
641 // fast_exp guarantees result >= 0
642 float e = fast_exp(scores[i] - max_score);
643 probs[i] = e;
644 sum += e;
645 }
646
647 // 3. SAFETY CHECK
648 // If sum is NaN or effectively zero (should be impossible with max-shift,
649 // but hardware glitches or NaN inputs can cause it), fallback to safety.
650 if (sum < 1e-9f || std::isnan(sum)) {
651 // Reset all to 0
652 for (int i = 0; i < n; ++i)
653 probs[i] = 0.0f;
654
655 // Default: 50% Tactical, 50% Strategic
656 // Assumes TACTICAL=0, STRATEGIC=1 as defined in enum
657 if (n >= 2) {
658 probs[0] = 0.5f;
659 probs[1] = 0.5f;
660 } else {
661 probs[0] = 1.0f;
662 }
663 return;
664 }
665
666 // 4. Normalize
667 float inv_sum = 1.0f / sum;
668 for (int i = 0; i < n; ++i) {
669 probs[i] *= inv_sum;
670 }
671 }
672};
Feature extraction logic for generating packed factorized features.
Standard sparse/dense feature extraction logic for chess positions.
Definition ExpertRouter.hpp:30
static constexpr float KILLER_CP_THRESHOLD
Definition ExpertRouter.hpp:67
static constexpr float AUX_RAMP_RANGE
Definition ExpertRouter.hpp:71
static float compute_phase(const ChessInput &input)
Definition ExpertRouter.hpp:279
static constexpr int NUM_EXPERTS
Definition ExpertRouter.hpp:32
static constexpr float SCORE_MAX
Definition ExpertRouter.hpp:87
static void compute_weights(const FactorizedInput &input, float eval_cp, ExpertWeights &out)
Definition ExpertRouter.hpp:183
static constexpr int NUM_BASE
Definition ExpertRouter.hpp:33
static void compute_weights(const ChessInput &input, float eval_cp, ExpertWeights &out)
Definition ExpertRouter.hpp:106
static float fast_exp(float x)
Definition ExpertRouter.hpp:593
static constexpr float SURVIVOR_CP_THRESHOLD
Definition ExpertRouter.hpp:69
static constexpr float CCT_WEIGHT_CHECK
Definition ExpertRouter.hpp:74
static constexpr float PHASE_ENDGAME_THRESHOLD
Definition ExpertRouter.hpp:60
static constexpr float PIECE_VALUE_ROOK
Definition ExpertRouter.hpp:81
static float compute_major_ratio(const FactorizedInput &input)
Definition ExpertRouter.hpp:468
static constexpr float CCT_STRATEGIC_THRESHOLD
Definition ExpertRouter.hpp:64
static constexpr float SCORE_MIN
Definition ExpertRouter.hpp:88
static constexpr int NUM_AUX
Definition ExpertRouter.hpp:34
static bool has_major_pieces(const ChessInput &input)
Definition ExpertRouter.hpp:422
static float compute_material_imbalance(const ChessInput &input)
Definition ExpertRouter.hpp:498
static float compute_phase(const FactorizedInput &input)
Definition ExpertRouter.hpp:283
static constexpr float PIECE_VALUE_PAWN
Definition ExpertRouter.hpp:78
static float piece_threat_value(const ChessInput &input, int threat_layer, int piece_layer, float piece_value)
Definition ExpertRouter.hpp:576
static float sigmoid(float x)
Definition ExpertRouter.hpp:618
static constexpr float CCT_WEIGHT_MAJOR_THREAT
Definition ExpertRouter.hpp:75
static constexpr float PIECE_VALUE_KNIGHT
Definition ExpertRouter.hpp:79
static float compute_major_ratio(const ChessInput &input)
Definition ExpertRouter.hpp:439
static float compute_cct(const ChessInput &input)
Definition ExpertRouter.hpp:291
static void softmax(const float *scores, float *probs, int n)
Definition ExpertRouter.hpp:630
static void get_top_k(const ExpertWeights &weights, int k, int *indices, float *probs)
Definition ExpertRouter.hpp:243
static float compute_cct(const FactorizedInput &input)
Definition ExpertRouter.hpp:359
static constexpr float PIECE_VALUE_QUEEN
Definition ExpertRouter.hpp:82
static int count_layer(const uint8_t layer[64])
Definition ExpertRouter.hpp:526
static constexpr float CCT_TACTICAL_THRESHOLD
Definition ExpertRouter.hpp:63
static bool has_overlap(const uint8_t a[64], const uint8_t b[64])
Definition ExpertRouter.hpp:548
Expert
Definition ExpertRouter.hpp:37
@ STRATEGIC
Definition ExpertRouter.hpp:40
@ KILLER
Definition ExpertRouter.hpp:45
@ TACTICAL
Definition ExpertRouter.hpp:39
@ MAJOR_END
Definition ExpertRouter.hpp:41
@ MINOR_END
Definition ExpertRouter.hpp:42
@ SURVIVOR
Definition ExpertRouter.hpp:44
static bool has_overlap_factorized(const FactorizedInput &input, int attacker_group_begin, int attacker_group_end, int target_presence_group)
Definition ExpertRouter.hpp:343
static constexpr float PIECE_VALUE_BISHOP
Definition ExpertRouter.hpp:80
Definition FactorizedFeatureExtractor.hpp:55
GlobalIndices
Definition FactorizedFeatureExtractor.hpp:126
@ PHASE
Definition FactorizedFeatureExtractor.hpp:141
GlobalIndices
Definition FeatureExtractor.hpp:72
@ PHASE
Definition FeatureExtractor.hpp:92
LayerIndices
Definition FeatureExtractor.hpp:26
Definition FeatureExtractor.hpp:10
uint8_t layers[32][64]
Definition FeatureExtractor.hpp:13
float global[16]
Definition FeatureExtractor.hpp:17
Definition ExpertRouter.hpp:52
float weights[NUM_EXPERTS]
Definition ExpertRouter.hpp:53
float raw_scores[NUM_EXPERTS]
Definition ExpertRouter.hpp:54
Definition FactorizedFeatureExtractor.hpp:47
float branches[12][MAX_BRANCH_PLANES][64]
Definition FactorizedFeatureExtractor.hpp:50
float global[32]
Definition FactorizedFeatureExtractor.hpp:52