48#include <condition_variable>
75#if defined(__ARM_NEON)
84using Clock = std::chrono::high_resolution_clock;
117#if defined(__clang__)
118#define FORCE_VECTORIZE \
119 _Pragma("clang loop vectorize(enable) interleave(enable)")
120#elif defined(__GNUC__) || defined(__GNUG__)
121#define FORCE_VECTORIZE _Pragma("GCC ivdep")
123#define FORCE_VECTORIZE
127#define HOT_RESTRICT __restrict
129#define HOT_RESTRICT __restrict__
132static inline double us(Clock::time_point a, Clock::time_point b) {
133 return (
double)std::chrono::duration_cast<std::chrono::nanoseconds>(b - a)
147 switch (branch_idx) {
193 Bitboard current_sliders = board.pieces(pt, side);
194 Bitboard old_sliders = old_board.pieces(pt, side);
195 if (current_sliders != old_sliders)
198 Bitboard changed_occupancy = board.occ() ^ old_board.occ();
199 if (!changed_occupancy)
202 Bitboard sliders_copy = current_sliders;
203 while (sliders_copy) {
206 if (pt == PieceType::BISHOP || pt == PieceType::QUEEN)
207 rays |= attacks::bishop(sq, board.occ());
208 if (pt == PieceType::ROOK || pt == PieceType::QUEEN)
209 rays |= attacks::rook(sq, board.occ());
210 if (rays & changed_occupancy)
224 int b = (s == Color::WHITE) ? 0 : 6;
226 if (pt == PieceType::PAWN)
228 else if (pt == PieceType::KNIGHT)
230 else if (pt == PieceType::BISHOP)
232 else if (pt == PieceType::ROOK)
234 else if (pt == PieceType::QUEEN)
236 else if (pt == PieceType::KING)
258 const Board &new_board,
260 std::array<uint8_t, 12> m{};
261 Color mover = old_board.sideToMove();
262 Color enemy = ~mover;
264 Piece p = old_board.at(mv.from());
265 if (p == Piece::NONE) {
271 if (mv.typeOf() == Move::PROMOTION)
273 if (mv.typeOf() == Move::CASTLING)
276 Square cs = Square::NO_SQ;
277 Piece cap = Piece::NONE;
278 if (mv.typeOf() == Move::ENPASSANT) {
279 cs = mv.to().ep_square();
280 cap =
Piece(PieceType::PAWN, enemy);
281 }
else if (old_board.isCapture(mv)) {
283 cap = old_board.at(mv.to());
285 if (cap != Piece::NONE)
289 if (sq == Square::NO_SQ || k == Square::NO_SQ)
291 return static_cast<bool>((attacks::king(k) |
Bitboard(1ULL << k.index())) &
295 Square wk = old_board.kingSq(Color::WHITE);
296 Square bk = old_board.kingSq(Color::BLACK);
297 if (king_touch(mv.from(), wk) || king_touch(mv.to(), wk) ||
298 king_touch(cs, wk) || p.type() == PieceType::KING) {
299 mark_head(m, Color::WHITE, PieceType::KING);
301 if (king_touch(mv.from(), bk) || king_touch(mv.to(), bk) ||
302 king_touch(cs, bk) || p.type() == PieceType::KING) {
303 mark_head(m, Color::BLACK, PieceType::KING);
308 for (
auto side : {Color::WHITE, Color::BLACK}) {
332 const size_t bytes = (size_t)ch * 64 *
sizeof(
float);
333 return std::memcmp(&a.
branches[branch_idx][0][0],
334 &b.
branches[branch_idx][0][0], bytes) != 0;
349 for (
int i = 0; i < n; ++i)
350 dst[i] += src[i] * scale;
358#if defined(__clang__) || defined(__GNUC__)
359 return static_cast<T *
>(__builtin_assume_aligned(ptr, 32));
366#if defined(__clang__) || defined(__GNUC__)
367 return static_cast<const T *
>(__builtin_assume_aligned(ptr, 32));
389 const float wC = wk[4];
391 for (
int sq = 0; sq < 64; ++sq)
392 out_plane[sq] += in_plane[sq] * wC;
395 const float wN = wk[1];
397 for (
int sq = 8; sq < 64; ++sq)
398 out_plane[sq] += in_plane[sq - 8] * wN;
400 const float wS = wk[7];
402 for (
int sq = 0; sq < 56; ++sq)
403 out_plane[sq] += in_plane[sq + 8] * wS;
405 const float wNW = wk[0];
406 for (
int r = 1; r < 8; ++r) {
408 for (
int c = 1; c < 8; ++c)
409 out_plane[r * 8 + c] += in_plane[(r - 1) * 8 + (c - 1)] * wNW;
412 const float wNE = wk[2];
413 for (
int r = 1; r < 8; ++r) {
415 for (
int c = 0; c < 7; ++c)
416 out_plane[r * 8 + c] += in_plane[(r - 1) * 8 + (c + 1)] * wNE;
419 const float wW = wk[3];
420 for (
int r = 0; r < 8; ++r) {
422 for (
int c = 1; c < 8; ++c)
423 out_plane[r * 8 + c] += in_plane[r * 8 + (c - 1)] * wW;
426 const float wE = wk[5];
427 for (
int r = 0; r < 8; ++r) {
429 for (
int c = 0; c < 7; ++c)
430 out_plane[r * 8 + c] += in_plane[r * 8 + (c + 1)] * wE;
433 const float wSW = wk[6];
434 for (
int r = 0; r < 7; ++r) {
436 for (
int c = 1; c < 8; ++c)
437 out_plane[r * 8 + c] += in_plane[(r + 1) * 8 + (c - 1)] * wSW;
440 const float wSE = wk[8];
441 for (
int r = 0; r < 7; ++r) {
443 for (
int c = 0; c < 7; ++c)
444 out_plane[r * 8 + c] += in_plane[(r + 1) * 8 + (c + 1)] * wSE;
461 for (
int sq = 0; sq < 64; ++sq)
464 for (
int i = 0; i < ic; ++i) {
465 const float *wk = &w[(size_t)i * 9];
466 const float *in_plane = &in[(size_t)i * 64];
471 for (
int sq = 0; sq < 64; ++sq)
472 out[sq] = std::max(0.0f, out[sq]);
494 for (
int o = 0; o < oc; ++o) {
495 float *out_plane = &aligned_out[(size_t)o * 64];
497 aligned_b[o], out_plane, ic);
508#if defined(__ARM_NEON)
509static inline float32x4_t neon_fma_n(float32x4_t acc, float32x4_t x,
float w) {
510#if defined(__aarch64__)
511 return vfmaq_n_f32(acc, x, w);
513 return vmlaq_n_f32(acc, x, w);
517static inline void conv3x3_row_accumulate_4oc_l0(
518 const float *w0,
const float *w1,
const float *w2,
const float *w3,
519 const float *in_row,
const float32x4_t &vzero, float32x4_t &a00,
520 float32x4_t &a01, float32x4_t &a10, float32x4_t &a11, float32x4_t &a20,
521 float32x4_t &a21, float32x4_t &a30, float32x4_t &a31,
int kr) {
522 const float32x4_t v_row_0 = vld1q_f32(in_row);
523 const float32x4_t v_row_1 = vld1q_f32(in_row + 4);
524 const float32x4_t v_l0 = vextq_f32(vzero, v_row_0, 3);
525 const float32x4_t v_l1 = vextq_f32(v_row_0, v_row_1, 3);
526 const float32x4_t v_r0 = vextq_f32(v_row_0, v_row_1, 1);
527 const float32x4_t v_r1 = vextq_f32(v_row_1, vzero, 1);
529 const int base = kr * 3;
530 const float wl00 = w0[base + 0], wc00 = w0[base + 1], wr00 = w0[base + 2];
531 const float wl10 = w1[base + 0], wc10 = w1[base + 1], wr10 = w1[base + 2];
532 const float wl20 = w2[base + 0], wc20 = w2[base + 1], wr20 = w2[base + 2];
533 const float wl30 = w3[base + 0], wc30 = w3[base + 1], wr30 = w3[base + 2];
535 a00 = neon_fma_n(a00, v_l0, wl00);
536 a01 = neon_fma_n(a01, v_l1, wl00);
537 a00 = neon_fma_n(a00, v_row_0, wc00);
538 a01 = neon_fma_n(a01, v_row_1, wc00);
539 a00 = neon_fma_n(a00, v_r0, wr00);
540 a01 = neon_fma_n(a01, v_r1, wr00);
542 a10 = neon_fma_n(a10, v_l0, wl10);
543 a11 = neon_fma_n(a11, v_l1, wl10);
544 a10 = neon_fma_n(a10, v_row_0, wc10);
545 a11 = neon_fma_n(a11, v_row_1, wc10);
546 a10 = neon_fma_n(a10, v_r0, wr10);
547 a11 = neon_fma_n(a11, v_r1, wr10);
549 a20 = neon_fma_n(a20, v_l0, wl20);
550 a21 = neon_fma_n(a21, v_l1, wl20);
551 a20 = neon_fma_n(a20, v_row_0, wc20);
552 a21 = neon_fma_n(a21, v_row_1, wc20);
553 a20 = neon_fma_n(a20, v_r0, wr20);
554 a21 = neon_fma_n(a21, v_r1, wr20);
556 a30 = neon_fma_n(a30, v_l0, wl30);
557 a31 = neon_fma_n(a31, v_l1, wl30);
558 a30 = neon_fma_n(a30, v_row_0, wc30);
559 a31 = neon_fma_n(a31, v_row_1, wc30);
560 a30 = neon_fma_n(a30, v_r0, wr30);
561 a31 = neon_fma_n(a31, v_r1, wr30);
564static inline void conv3x3_relu_bd16_neon(
const float *in,
const float *w,
565 const float *b,
float *out,
567 const float32x4_t vzero = vdupq_n_f32(0.0f);
569 for (
int oc = 0; oc < 16; oc += 4) {
570 for (
int r = 0; r < 8; ++r) {
571 float32x4_t a00 = vdupq_n_f32(b[oc + 0]);
572 float32x4_t a01 = vdupq_n_f32(b[oc + 0]);
573 float32x4_t a10 = vdupq_n_f32(b[oc + 1]);
574 float32x4_t a11 = vdupq_n_f32(b[oc + 1]);
575 float32x4_t a20 = vdupq_n_f32(b[oc + 2]);
576 float32x4_t a21 = vdupq_n_f32(b[oc + 2]);
577 float32x4_t a30 = vdupq_n_f32(b[oc + 3]);
578 float32x4_t a31 = vdupq_n_f32(b[oc + 3]);
580 for (
int ic = 0; ic < ic_total; ++ic) {
581 const float *in_plane = &in[(size_t)ic * 64];
582 const float *w0 = &w[((size_t)(oc + 0) * ic_total + (size_t)ic) * 9];
583 const float *w1 = &w[((size_t)(oc + 1) * ic_total + (size_t)ic) * 9];
584 const float *w2 = &w[((size_t)(oc + 2) * ic_total + (size_t)ic) * 9];
585 const float *w3 = &w[((size_t)(oc + 3) * ic_total + (size_t)ic) * 9];
588 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[0], vzero,
589 a00, a01, a10, a11, a20, a21, a30, a31,
591 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[8], vzero,
592 a00, a01, a10, a11, a20, a21, a30, a31,
595 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[48], vzero,
596 a00, a01, a10, a11, a20, a21, a30, a31,
598 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[56], vzero,
599 a00, a01, a10, a11, a20, a21, a30, a31,
602 conv3x3_row_accumulate_4oc_l0(
603 w0, w1, w2, w3, &in_plane[(
size_t)(r - 1) * 8], vzero, a00, a01,
604 a10, a11, a20, a21, a30, a31, 0);
605 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3,
606 &in_plane[(
size_t)r * 8], vzero, a00,
607 a01, a10, a11, a20, a21, a30, a31, 1);
608 conv3x3_row_accumulate_4oc_l0(
609 w0, w1, w2, w3, &in_plane[(
size_t)(r + 1) * 8], vzero, a00, a01,
610 a10, a11, a20, a21, a30, a31, 2);
614 const size_t row_base = (size_t)r * 8;
615 vst1q_f32(&out[(
size_t)(oc + 0) * 64 + row_base + 0],
616 vmaxq_f32(vzero, a00));
617 vst1q_f32(&out[(
size_t)(oc + 0) * 64 + row_base + 4],
618 vmaxq_f32(vzero, a01));
619 vst1q_f32(&out[(
size_t)(oc + 1) * 64 + row_base + 0],
620 vmaxq_f32(vzero, a10));
621 vst1q_f32(&out[(
size_t)(oc + 1) * 64 + row_base + 4],
622 vmaxq_f32(vzero, a11));
623 vst1q_f32(&out[(
size_t)(oc + 2) * 64 + row_base + 0],
624 vmaxq_f32(vzero, a20));
625 vst1q_f32(&out[(
size_t)(oc + 2) * 64 + row_base + 4],
626 vmaxq_f32(vzero, a21));
627 vst1q_f32(&out[(
size_t)(oc + 3) * 64 + row_base + 0],
628 vmaxq_f32(vzero, a30));
629 vst1q_f32(&out[(
size_t)(oc + 3) * 64 + row_base + 4],
630 vmaxq_f32(vzero, a31));
654 for (
int c = 0; c < channels; ++c) {
655 const float *in_plane = &in[(size_t)c * 64];
656 float *out_plane = &out[(size_t)c * 64];
657 const float *wk = &w[(size_t)c * 9];
660 for (
int sq = 0; sq < 64; ++sq)
661 out_plane[sq] = b[c];
666 for (
int sq = 0; sq < 64; ++sq)
667 out_plane[sq] = std::max(0.0f, out_plane[sq]);
695 for (
int o = 0; o < oc; ++o) {
696 const float bias = aligned_b[o];
697 float *out_plane = &aligned_out[(size_t)o * 64];
699 for (
int sq = 0; sq < 64; ++sq)
700 out_plane[sq] = bias;
704 for (
int o = 0; o < oc; ++o) {
705 float *out_plane = &aligned_out[(size_t)o * 64];
706 for (
int i = 0; i < ic; ++i) {
707 const float weight = aligned_w[(size_t)o * ic + i];
708 const float *in_plane = &aligned_in[(size_t)i * 64];
710 for (
int sq = 0; sq < 64; ++sq)
711 out_plane[sq] += in_plane[sq] * weight;
716 for (
int o = 0; o < oc; ++o) {
717 float *out_plane = &aligned_out[(size_t)o * 64];
719 for (
int sq = 0; sq < 64; ++sq)
720 out_plane[sq] = std::max(0.0f, out_plane[sq]);
724#if defined(__ARM_NEON)
725static inline void conv1x1_bd16_relu_neon(
const float *
HOT_RESTRICT in,
729 const float32x4_t vzero = vdupq_n_f32(0.0f);
731 for (
int oc = 0; oc < 16; oc += 4) {
732 const float *w0 = &w[(size_t)(oc + 0) * 16];
733 const float *w1 = &w[(size_t)(oc + 1) * 16];
734 const float *w2 = &w[(size_t)(oc + 2) * 16];
735 const float *w3 = &w[(size_t)(oc + 3) * 16];
737 for (
int sq = 0; sq < 64; sq += 4) {
738 float32x4_t acc0 = vdupq_n_f32(b[oc + 0]);
739 float32x4_t acc1 = vdupq_n_f32(b[oc + 1]);
740 float32x4_t acc2 = vdupq_n_f32(b[oc + 2]);
741 float32x4_t acc3 = vdupq_n_f32(b[oc + 3]);
743 for (
int ic = 0; ic < 16; ++ic) {
744 const float32x4_t vin = vld1q_f32(&in[(
size_t)ic * 64 + (
size_t)sq]);
745#if defined(__aarch64__)
746 acc0 = vfmaq_n_f32(acc0, vin, w0[ic]);
747 acc1 = vfmaq_n_f32(acc1, vin, w1[ic]);
748 acc2 = vfmaq_n_f32(acc2, vin, w2[ic]);
749 acc3 = vfmaq_n_f32(acc3, vin, w3[ic]);
751 acc0 = vmlaq_n_f32(acc0, vin, w0[ic]);
752 acc1 = vmlaq_n_f32(acc1, vin, w1[ic]);
753 acc2 = vmlaq_n_f32(acc2, vin, w2[ic]);
754 acc3 = vmlaq_n_f32(acc3, vin, w3[ic]);
758 vst1q_f32(&out[(
size_t)(oc + 0) * 64 + (
size_t)sq],
759 vmaxq_f32(vzero, acc0));
760 vst1q_f32(&out[(
size_t)(oc + 1) * 64 + (
size_t)sq],
761 vmaxq_f32(vzero, acc1));
762 vst1q_f32(&out[(
size_t)(oc + 2) * 64 + (
size_t)sq],
763 vmaxq_f32(vzero, acc2));
764 vst1q_f32(&out[(
size_t)(oc + 3) * 64 + (
size_t)sq],
765 vmaxq_f32(vzero, acc3));
772static inline __m256 avx2_fmadd_ps(__m256 a, __m256 b, __m256 c) {
774 return _mm256_fmadd_ps(a, b, c);
776 return _mm256_add_ps(_mm256_mul_ps(a, b), c);
780static inline void conv1x1_bd16_relu_avx2(
const float *
HOT_RESTRICT in,
784 const __m256 vzero = _mm256_setzero_ps();
785 assert((
reinterpret_cast<uintptr_t
>(in) & 31u) == 0u);
786 assert((
reinterpret_cast<uintptr_t
>(out) & 31u) == 0u);
788 for (
int oc = 0; oc < 16; oc += 4) {
789 const float *w0 = &w[(size_t)(oc + 0) * 16];
790 const float *w1 = &w[(size_t)(oc + 1) * 16];
791 const float *w2 = &w[(size_t)(oc + 2) * 16];
792 const float *w3 = &w[(size_t)(oc + 3) * 16];
794 for (
int sq = 0; sq < 64; sq += 8) {
795 __m256 acc0 = _mm256_set1_ps(b[oc + 0]);
796 __m256 acc1 = _mm256_set1_ps(b[oc + 1]);
797 __m256 acc2 = _mm256_set1_ps(b[oc + 2]);
798 __m256 acc3 = _mm256_set1_ps(b[oc + 3]);
800 for (
int ic = 0; ic < 16; ++ic) {
801 const __m256 vin = _mm256_load_ps(&in[(
size_t)ic * 64 + (
size_t)sq]);
802 acc0 = avx2_fmadd_ps(vin, _mm256_set1_ps(w0[ic]), acc0);
803 acc1 = avx2_fmadd_ps(vin, _mm256_set1_ps(w1[ic]), acc1);
804 acc2 = avx2_fmadd_ps(vin, _mm256_set1_ps(w2[ic]), acc2);
805 acc3 = avx2_fmadd_ps(vin, _mm256_set1_ps(w3[ic]), acc3);
808 _mm256_store_ps(&out[(
size_t)(oc + 0) * 64 + (
size_t)sq],
809 _mm256_max_ps(vzero, acc0));
810 _mm256_store_ps(&out[(
size_t)(oc + 1) * 64 + (
size_t)sq],
811 _mm256_max_ps(vzero, acc1));
812 _mm256_store_ps(&out[(
size_t)(oc + 2) * 64 + (
size_t)sq],
813 _mm256_max_ps(vzero, acc2));
814 _mm256_store_ps(&out[(
size_t)(oc + 3) * 64 + (
size_t)sq],
815 _mm256_max_ps(vzero, acc3));
850 const int r = sq >> 3;
851 const int c = sq & 7;
852 return ((r >> 1) << 2) | (c >> 1);
856 r0 = (region >> 2) * 2;
857 c0 = (region & 3) * 2;
930 alignas(32) std::array<float, NET_BRANCH_DIM>
b{};
942 alignas(32) std::array<float, NET_EXPERT_BOTTLENECK>
bConv{};
948 alignas(32) std::array<
956 alignas(32) std::array<
966 alignas(32) std::array<float, NET_EXPERT_HIDDEN>
bH{};
970 alignas(32) std::array<float, 3>
bWdl{};
999 alignas(64) std::array<
float, (
size_t)
NET_BYPASS *
1001 alignas(64) std::array<float, NET_MIXER_OUT>
mixerB{};
1005 alignas(64) std::array<float, NET_MIXER_OUT>
globalB{};
1007 alignas(64) std::array<
1009 alignas(64) std::array<float, NET_EXPERTS>
gateB{};
1014 auto init_branch = [&](
int ic) {
1019 if (branchConvLayers >= 2) {
1024 if (branchConvLayers >= 3) {
1031 for (
int b = 0; b < 12; ++b)
1084 alignas(64) std::array<std::array<float, kMaxExpertBottleneck>,
1086 alignas(64) std::array<
1095 alignas(64) std::array<std::array<
float, (
size_t)
kMaxBranchDim * 64>,
1097 alignas(64) std::array<std::array<
float, (
size_t)
kMaxBranchDim * 64>,
1132 throw std::runtime_error(
"MoE weights pointer is null");
1148 std::fill(v.begin(), v.end(), 0.0f);
1150 std::fill(v.begin(), v.end(), 0.0f);
1153 std::fill(v.begin(), v.end(), 0.0f);
1155 std::fill(v.begin(), v.end(), 0.0f);
1157 std::fill(v.begin(), v.end(), 0.0f);
1196 template <
typename Fn>
1201 for (
int i = 0; i < n; ++i)
1205 threadPool->parallel_for(n, std::function<
void(
int)>(std::forward<Fn>(fn)));
1214 long long total = 0;
1215 auto add_vec = [&](
const auto &v) { total += (
long long)v.size(); };
1217 for (
const auto &br : w.branches) {
1226 add_vec(w.mixerWBr);
1227 add_vec(w.mixerWBp);
1234 for (
const auto &ex : w.experts) {
1253 if (w.experts.empty())
1256 long long total = 0;
1257 const Expert &ex = w.experts.front();
1258 auto add_vec = [&](
const auto &v) { total += (
long long)v.size(); };
1277 long long total = 0;
1278 auto add_vec = [&](
const auto &v) { total += (
long long)v.size(); };
1280 for (
const auto &ex : w.experts) {
1302 const int k = std::clamp(topk, 0,
nExperts);
1310 throw std::runtime_error(
1311 "Weights dimensions do not match fixed native architecture");
1319 throw std::runtime_error(
"Null shared MoE weights");
1343 std::mt19937 rng(seed);
1344 std::normal_distribution<float> nd(0.0f, 0.05f);
1349 auto fill = [&](
auto &v) {
1354 for (
auto &br : w.branches) {
1375 for (
auto &ex : w.experts) {
1380 for (
int h = 0; h <
eh; ++h) {
1381 for (
int f = 0; f <
ebo; ++f) {
1382 ex.wHGT[(size_t)f *
eh + h] = ex.wHG[(
size_t)h *
ebo + f];
1387 for (
int h = 0; h <
eh; ++h) {
1388 for (
int f = 0; f <
ebo * 64; ++f) {
1389 ex.wHT[(size_t)f *
eh + h] = ex.wH[(
size_t)h * (
ebo * 64) + f];
1394 for (
int h = 0; h <
eh; ++h) {
1396 ex.wH16T[(size_t)f *
eh + h] =
1412#if defined(__ARM_NEON)
1413 conv3x3_relu_bd16_neon(in_planes, br.
l0.
w.data(), br.
l0.
b.data(), mid_plane,
1417 conv1x1_bd16_relu_neon(l1_accum, br.
l2.
w.data(), br.
l2.
b.data(), out);
1418#elif defined(__AVX2__)
1420 mid_plane, br.
l0.
ic);
1423 conv1x1_bd16_relu_avx2(l1_accum, br.
l2.
w.data(), br.
l2.
b.data(), out);
1426 mid_plane, br.
l0.
ic);
1434 float *scratch0,
float *scratch1) {
1442 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), out, br.l0.ic,
1450 scratch0, br.l0.ic);
1451 conv3x3_relu(scratch0, br.l1.w.data(), br.l1.b.data(), out, 16, 16);
1453 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), scratch0,
1464 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), scratch0, br.l0.ic,
1480 const int flat_sz =
ebo * 64;
1482 for (
int out = 0; out <
eh; ++out)
1483 hacc[(
size_t)out] = ex.bH[(size_t)out];
1485 for (
int i = 0; i < flat_sz; ++i) {
1486 const float fv = flat[(size_t)i];
1487 if (std::abs(fv) < 1e-7f)
1489 const float *wt = &ex.wHT[(size_t)i *
eh];
1490 for (
int out = 0; out <
eh; ++out)
1491 hacc[(
size_t)out] += fv * wt[out];
1500 constexpr float inv64 = 1.0f / 64.0f;
1502 for (
int bo = 0; bo <
ebo; ++bo) {
1504 const float *rbo = &flat[(size_t)bo * 64];
1506 for (
int sq = 0; sq < 64; ++sq)
1508 gap[(size_t)bo] = s * inv64;
1511 for (
int out = 0; out <
eh; ++out)
1512 hacc[(
size_t)out] = ex.bH[(size_t)out];
1514 for (
int bo = 0; bo <
ebo; ++bo) {
1515 const float gv = gap[(size_t)bo];
1516 if (std::abs(gv) < 1e-7f)
1518 const float *wt = &ex.wHGT[(size_t)bo *
eh];
1519 for (
int out = 0; out <
eh; ++out)
1520 hacc[(
size_t)out] += gv * wt[out];
1530 for (
int bo = 0; bo <
ebo; ++bo) {
1531 const float *rbo = &flat[(size_t)bo * 64];
1535 const int sq0 = r0 * 8 + c0;
1536 const int sq1 = sq0 + 1;
1537 const int sq2 = sq0 + 8;
1538 const int sq3 = sq2 + 1;
1542 pv = std::max(std::max(rbo[sq0], rbo[sq1]),
1543 std::max(rbo[sq2], rbo[sq3]));
1545 pv = 0.25f * (rbo[sq0] + rbo[sq1] + rbo[sq2] + rbo[sq3]);
1551 for (
int out = 0; out <
eh; ++out)
1552 hacc[(
size_t)out] = ex.bH[(size_t)out];
1555 for (
int i = 0; i < pool_sz; ++i) {
1556 const float pv = pool[(size_t)i];
1557 if (std::abs(pv) < 1e-7f)
1559 const float *wt = &ex.wH16T[(size_t)i *
eh];
1560 for (
int out = 0; out <
eh; ++out)
1561 hacc[(
size_t)out] += pv * wt[out];
1567 float s = shared.globalB[(size_t)oc];
1568 const float *w = &shared.globalW[(size_t)oc *
nGlobals];
1578 float s0 = std::numeric_limits<float>::lowest();
1579 float s1 = std::numeric_limits<float>::lowest();
1582 for (
int e = 0; e <
nExperts; ++e) {
1583 float s = shared.gateB[(size_t)e];
1584 const float *w = &shared.gateW[(size_t)e *
nGlobals];
1586 const int slow_n = std::min(
nGlobals, 15);
1587 for (
int i = 0; i < slow_n; ++i)
1588 s += w[i] * global[i];
1591 s += w[i] * global[i];
1599 }
else if (s > s1) {
1605 const float m = std::max(s0, s1);
1606 const float p0 = std::exp(s0 - m);
1607 const float p1 = std::exp(s1 - m);
1608 const float z = p0 + p1;
1622 for (
int bo = 0; bo <
ebo; ++bo) {
1623 float *pbo = &pre[(size_t)bo * 64];
1625 const float b = ex.bConv[(size_t)bo];
1626 for (
int sq = 0; sq < 64; ++sq)
1629 const float *w = &ex.wConv[(size_t)bo *
nf];
1630 for (
int oc = 0; oc <
nf; ++oc) {
1631 const float wc = w[oc];
1634 for (
int sq = 0; sq < 64; ++sq)
1635 pbo[sq] += in_plane[sq] * wc;
1638 float *rbo = &rel[(size_t)bo * 64];
1639 for (
int sq = 0; sq < 64; ++sq)
1640 rbo[sq] = std::max(0.0f, pbo[sq]);
1666 const int *active_experts,
int active_count) {
1668 const float *w_mixer_br = shared.mixerWBr.data();
1669 const float *w_mixer_bp = shared.mixerWBp.data();
1670 auto tBranch0 = Clock::now();
1673 const float *pin = &inp.
branches[b][0][0];
1679 auto tBranch1 = Clock::now();
1680 profile.fullBranchForwardUs +=
us(tBranch0, tBranch1);
1683 auto tMix0 = Clock::now();
1685 for (
int oc = 0; oc <
nf; ++oc) {
1687 const float bias = shared.mixerB[(size_t)oc];
1688 for (
int sq = 0; sq < 64; ++sq)
1698 for (
int t = 0; t < ocl; ++t)
1701 for (
int b = 0; b < 12; ++b) {
1703 for (
int c = 0; c <
bd; ++c) {
1704 const float *in_plane = &bo[(size_t)c * 64];
1705 for (
int t = 0; t < ocl; ++t) {
1706 const float wc = w_mixer_br[((size_t)b *
nf + (oc0 + t)) *
bd + c];
1712 for (
int bp = 0; bp <
nBypass; ++bp) {
1713 const float *in_plane = &inp.
bypass[bp][0];
1714 for (
int t = 0; t < ocl; ++t) {
1715 const float wc = w_mixer_bp[(size_t)bp *
nf + (oc0 + t)];
1720 auto tMix1 = Clock::now();
1721 profile.fullMixerAccumUs +=
us(tMix0, tMix1);
1724 auto tRelu0 = Clock::now();
1725 for (
int oc = 0; oc <
nf; ++oc) {
1727 for (
int sq = 0; sq < 64; ++sq) {
1732 auto tRelu1 = Clock::now();
1733 profile.fullMixerReluUs +=
us(tRelu0, tRelu1);
1736 auto tEx0 = Clock::now();
1738 if (active_count == 2 && active_experts[0] != active_experts[1]) {
1743 for (
int i = 0; i < active_count; ++i)
1746 auto tEx1 = Clock::now();
1747 profile.fullExpertCacheUs +=
us(tEx0, tEx1);
1774 const int *dirty_branches,
int dirty_count,
1775 const int *active_experts,
int active_count) {
1777 const float *w_mixer_br = shared.mixerWBr.data();
1778 const float *w_mixer_bp = shared.mixerWBp.data();
1784 uint64_t dirty_sq_mask = 0ULL;
1789 auto tA0 = Clock::now();
1791 const int b = dirty_branches[i];
1792 const float *pin = &cur.
branches[b][0][0];
1799 for (
int i = 0; i < dirty_count; ++i) {
1800 const int b = dirty_branches[i];
1801 const float *new_branch =
1805 for (
int c = 0; c <
bd; ++c) {
1806 const float *old_plane = &old_cache[(size_t)c * 64];
1807 const float *new_plane = &new_branch[(size_t)c * 64];
1811 for (
int sq = 0; sq < 64; ++sq)
1812 delta_plane[sq] = new_plane[sq] - old_plane[sq];
1814 for (
int sq = 0; sq < 64; ++sq) {
1815 if (std::abs(delta_plane[sq]) > 1e-9f)
1816 dirty_sq_mask |= (1ULL << sq);
1822 constexpr int kTileOC = 8;
1823 const int ocTiles = (
nf + kTileOC - 1) / kTileOC;
1824 for (
int tile = 0; tile < ocTiles; ++tile) {
1825 const int oc0 = tile * kTileOC;
1826 const int ocl = std::min(kTileOC,
nf - oc0);
1827 for (
int c = 0; c <
bd; ++c) {
1829 for (
int t = 0; t < ocl; ++t) {
1830 const float w = w_mixer_br[((size_t)b *
nf + (oc0 + t)) *
bd + c];
1837 std::memcpy(old_cache, new_branch, (
size_t)
bd * 64 *
sizeof(
float));
1839 auto tA1 = Clock::now();
1840 profile.incBranchDeltaUs +=
us(tA0, tA1);
1843 auto tB0 = Clock::now();
1844 std::array<uint8_t, kMaxBypass> bypass_dirty{};
1845 for (
int bp = 0; bp <
nBypass; ++bp) {
1846 bool has_delta =
false;
1847 for (
int sq = 0; sq < 64; ++sq) {
1848 const float delta = cur.
bypass[bp][sq] - prev.
bypass[bp][sq];
1850 if (std::abs(delta) > 1e-9f) {
1852 dirty_sq_mask |= (1ULL << sq);
1855 bypass_dirty[(size_t)bp] = has_delta ? 1u : 0u;
1860 constexpr int kTileOC = 8;
1861 const int ocTiles = (
nf + kTileOC - 1) / kTileOC;
1862 for (
int tile = 0; tile < ocTiles; ++tile) {
1863 const int oc0 = tile * kTileOC;
1864 const int ocl = std::min(kTileOC,
nf - oc0);
1865 for (
int bp = 0; bp <
nBypass; ++bp) {
1866 if (!bypass_dirty[(
size_t)bp])
1869 for (
int t = 0; t < ocl; ++t) {
1870 const float w = w_mixer_bp[(size_t)bp *
nf + (oc0 + t)];
1877 auto tB1 = Clock::now();
1878 profile.incBypassDeltaUs +=
us(tB0, tB1);
1882 auto tCD0 = Clock::now();
1883 bool globals_changed =
false;
1884 for (
int g = 0; g <
nGlobals; ++g) {
1886 globals_changed =
true;
1890 auto tCD1 = Clock::now();
1891 profile.incGlobalReluUs +=
us(tCD0, tCD1);
1893 if (globals_changed)
1894 dirty_sq_mask = ~0ULL;
1896 for (
int oc = 0; oc <
nf; ++oc)
1900 bool any_mixer_delta =
false;
1901 uint64_t delta_sq_mask = 0ULL;
1902 for (
int oc = 0; oc <
nf; ++oc) {
1904 uint64_t mask = dirty_sq_mask;
1906 const int sq = __builtin_ctzll(mask);
1909 const size_t idx = (size_t)oc * 64 + sq;
1912 const float d = newv - oldv;
1916 any_mixer_delta =
true;
1917 delta_sq_mask |= (1ULL << sq);
1922 int dirty_sq_idx[64];
1923 int dirty_sq_count = 0;
1925 int dirty_region_count = 0;
1926 uint32_t region_mask = 0;
1927 uint64_t work = delta_sq_mask;
1929 const int sq = __builtin_ctzll(work);
1931 dirty_sq_idx[dirty_sq_count++] = sq;
1936 while (region_mask) {
1937 const int region = __builtin_ctz(region_mask);
1938 region_mask &= (region_mask - 1);
1939 dirty_region_idx[dirty_region_count++] = region;
1943 if (!any_mixer_delta) {
1944 std::array<double, kMaxExperts> reb_us{};
1946 const int e = active_experts[i];
1948 const auto tReb0 = Clock::now();
1950 const auto tReb1 = Clock::now();
1951 reb_us[(size_t)i] =
us(tReb0, tReb1);
1954 for (
int i = 0; i < active_count; ++i)
1955 profile.incExpertCacheRebuildUs += reb_us[(
size_t)i];
1959 std::array<uint8_t, kMaxExperts> new_valid{};
1960 std::array<int, kMaxExperts> processed_e{};
1961 std::array<double, kMaxExperts> reb_us{};
1962 std::array<double, kMaxExperts> bott_us{};
1963 std::array<double, kMaxExperts> hidden_us{};
1966 const int e = active_experts[i];
1967 processed_e[(size_t)i] = e;
1970 const auto tReb0 = Clock::now();
1972 const auto tReb1 = Clock::now();
1973 reb_us[(size_t)i] =
us(tReb0, tReb1);
1982 const auto &ex = shared.experts[(size_t)e];
1983 const int flat_sz =
ebo * 64;
1985 std::array<float, kMaxExpertBottleneck> localGapDelta{};
1989 std::fill(localFlatDelta, localFlatDelta + flat_sz, 0.0f);
1991 const auto tBott0 = Clock::now();
1993 constexpr float inv64 = 1.0f / 64.0f;
1994 for (
int bo = 0; bo <
ebo; ++bo) {
1995 float *pbo = &pre[(size_t)bo * 64];
1997 const float *w = &ex.wConv[(size_t)bo *
nf];
1998 for (
int ic = 0; ic <
nf; ++ic) {
1999 const float wc = w[ic];
2002 for (
int sq = 0; sq < 64; ++sq)
2003 pbo[sq] += dr[sq] * wc;
2006 const float *w = &ex.wConv[(size_t)bo *
nf];
2007 for (
int k = 0; k < dirty_sq_count; ++k) {
2008 const int sq = dirty_sq_idx[k];
2010 for (
int ic = 0; ic <
nf; ++ic)
2016 float *rbo = &rel[(size_t)bo * 64];
2017 const float *pbo_ro = &pre[(size_t)bo * 64];
2018 float ch_sum_delta = 0.0f;
2020 for (
int sq = 0; sq < 64; ++sq) {
2021 const float old_flat = rbo[sq];
2022 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2024 ch_sum_delta += (new_flat - old_flat);
2027 for (
int k = 0; k < dirty_sq_count; ++k) {
2028 const int sq = dirty_sq_idx[k];
2029 const float old_flat = rbo[sq];
2030 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2032 ch_sum_delta += (new_flat - old_flat);
2035 const float gd = ch_sum_delta * inv64;
2036 localGapDelta[(size_t)bo] = gd;
2037 gap[(size_t)bo] += gd;
2040 for (
int bo = 0; bo <
ebo; ++bo) {
2041 float *pbo = &pre[(size_t)bo * 64];
2043 const float *w = &ex.wConv[(size_t)bo *
nf];
2044 for (
int ic = 0; ic <
nf; ++ic) {
2045 const float wc = w[ic];
2048 for (
int sq = 0; sq < 64; ++sq)
2049 pbo[sq] += dr[sq] * wc;
2052 const float *w = &ex.wConv[(size_t)bo *
nf];
2053 for (
int k = 0; k < dirty_sq_count; ++k) {
2054 const int sq = dirty_sq_idx[k];
2056 for (
int ic = 0; ic <
nf; ++ic)
2062 float *rbo = &rel[(size_t)bo * 64];
2063 const float *pbo_ro = &pre[(size_t)bo * 64];
2065 for (
int sq = 0; sq < 64; ++sq) {
2066 const float old_flat = rbo[sq];
2067 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2069 localFlatDelta[(size_t)bo * 64 + sq] = new_flat - old_flat;
2072 for (
int k = 0; k < dirty_sq_count; ++k) {
2073 const int sq = dirty_sq_idx[k];
2074 const float old_flat = rbo[sq];
2075 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2077 localFlatDelta[(size_t)bo * 64 + sq] = new_flat - old_flat;
2082 constexpr float inv4 = 0.25f;
2083 for (
int bo = 0; bo <
ebo; ++bo) {
2084 float *pbo = &pre[(size_t)bo * 64];
2086 const float *w = &ex.wConv[(size_t)bo *
nf];
2087 for (
int ic = 0; ic <
nf; ++ic) {
2088 const float wc = w[ic];
2091 for (
int sq = 0; sq < 64; ++sq)
2092 pbo[sq] += dr[sq] * wc;
2095 const float *w = &ex.wConv[(size_t)bo *
nf];
2096 for (
int k = 0; k < dirty_sq_count; ++k) {
2097 const int sq = dirty_sq_idx[k];
2099 for (
int ic = 0; ic <
nf; ++ic)
2105 float *rbo = &rel[(size_t)bo * 64];
2106 const float *pbo_ro = &pre[(size_t)bo * 64];
2108 for (
int sq = 0; sq < 64; ++sq) {
2109 const float old_flat = rbo[sq];
2110 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2114 (new_flat - old_flat) * inv4;
2117 for (
int k = 0; k < dirty_sq_count; ++k) {
2118 const int sq = dirty_sq_idx[k];
2119 const float old_flat = rbo[sq];
2120 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2124 (new_flat - old_flat) * inv4;
2128 const int regions_to_loop =
2130 for (
int rk = 0; rk < regions_to_loop; ++rk) {
2131 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2133 pool16[pidx] += localPoolDelta[pidx];
2137 for (
int bo = 0; bo <
ebo; ++bo) {
2138 float *pbo = &pre[(size_t)bo * 64];
2139 float *rbo = &rel[(size_t)bo * 64];
2141 const float *w = &ex.wConv[(size_t)bo *
nf];
2142 for (
int ic = 0; ic <
nf; ++ic) {
2143 const float wc = w[ic];
2146 for (
int sq = 0; sq < 64; ++sq)
2147 pbo[sq] += dr[sq] * wc;
2150 const float *w = &ex.wConv[(size_t)bo *
nf];
2151 for (
int k = 0; k < dirty_sq_count; ++k) {
2152 const int sq = dirty_sq_idx[k];
2154 for (
int ic = 0; ic <
nf; ++ic)
2161 for (
int sq = 0; sq < 64; ++sq)
2162 rbo[sq] = std::max(0.0f, pbo[sq]);
2164 for (
int k = 0; k < dirty_sq_count; ++k) {
2165 const int sq = dirty_sq_idx[k];
2166 rbo[sq] = std::max(0.0f, pbo[sq]);
2170 const int regions_to_loop =
2172 for (
int rk = 0; rk < regions_to_loop; ++rk) {
2173 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2176 const int sq0 = r0 * 8 + c0;
2177 const int sq1 = sq0 + 1;
2178 const int sq2 = sq0 + 8;
2179 const int sq3 = sq2 + 1;
2180 const float new_pool = std::max(std::max(rbo[sq0], rbo[sq1]),
2181 std::max(rbo[sq2], rbo[sq3]));
2183 const float old_pool = pool16[pidx];
2184 pool16[pidx] = new_pool;
2185 localPoolDelta[pidx] = new_pool - old_pool;
2189 const auto tBott1 = Clock::now();
2190 bott_us[(size_t)i] =
us(tBott0, tBott1);
2192 const auto tHidden0 = Clock::now();
2194 for (
int bo = 0; bo <
ebo; ++bo) {
2195 const float delta = localGapDelta[(size_t)bo];
2196 if (std::abs(delta) < 1e-9f)
2198 const float *wt = &ex.wHGT[(size_t)bo *
eh];
2199 for (
int out = 0; out <
eh; ++out)
2200 hacc[(
size_t)out] += delta * wt[out];
2204 for (
int fi = 0; fi < flat_sz; ++fi) {
2205 const float delta = localFlatDelta[(size_t)fi];
2206 if (std::abs(delta) < 1e-7f)
2208 const float *wt = &ex.wHT[(size_t)fi *
eh];
2210 for (
int out = 0; out <
eh; ++out)
2211 hacc[(
size_t)out] += delta * wt[out];
2214 for (
int bo = 0; bo <
ebo; ++bo) {
2215 for (
int k = 0; k < dirty_sq_count; ++k) {
2216 const int sq = dirty_sq_idx[k];
2217 const int fi = bo * 64 + sq;
2218 const float delta = localFlatDelta[(size_t)fi];
2219 if (std::abs(delta) < 1e-7f)
2221 const float *wt = &ex.wHT[(size_t)fi *
eh];
2222 for (
int out = 0; out <
eh; ++out)
2223 hacc[(
size_t)out] += delta * wt[out];
2228 const int regions_to_loop =
2230 for (
int bo = 0; bo <
ebo; ++bo) {
2231 for (
int rk = 0; rk < regions_to_loop; ++rk) {
2232 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2234 const float delta = localPoolDelta[(size_t)pidx];
2235 if (std::abs(delta) < 1e-9f)
2237 const float *wt = &ex.wH16T[(size_t)pidx *
eh];
2239 for (
int out = 0; out <
eh; ++out)
2240 hacc[(
size_t)out] += delta * wt[out];
2244 const int regions_to_loop =
2246 for (
int bo = 0; bo <
ebo; ++bo) {
2247 for (
int rk = 0; rk < regions_to_loop; ++rk) {
2248 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2250 const float delta = localPoolDelta[(size_t)pidx];
2251 if (std::abs(delta) < 1e-9f)
2253 const float *wt = &ex.wH16T[(size_t)pidx *
eh];
2255 for (
int out = 0; out <
eh; ++out)
2256 hacc[(
size_t)out] += delta * wt[out];
2260 const auto tHidden1 = Clock::now();
2261 hidden_us[(size_t)i] =
us(tHidden0, tHidden1);
2264 for (
int i = 0; i < active_count; ++i) {
2265 profile.incExpertCacheRebuildUs += reb_us[(size_t)i];
2266 profile.incExpertBottleneckUs += bott_us[(size_t)i];
2267 profile.incHiddenDeltaUs += hidden_us[(size_t)i];
2268 new_valid[(size_t)processed_e[(
size_t)i]] = 1;
2285 out_wdl[0] = out_wdl[1] = out_wdl[2] = 0.0f;
2288 const auto &hacc =
hiddenAcc[(size_t)e];
2290 for (
int h = 0; h <
eh; ++h)
2293 float wdl_logits[3] = {ex.bWdl[0], ex.bWdl[1], ex.bWdl[2]};
2294 for (
int o = 0; o < 3; ++o) {
2295 const float *w = &ex.wWdl[(size_t)o *
eh];
2296 for (
int h = 0; h <
eh; ++h)
2301 std::max(wdl_logits[0], std::max(wdl_logits[1], wdl_logits[2]));
2304 for (
int o = 0; o < 3; ++o) {
2305 wdl_prob[o] = std::exp(wdl_logits[o] - mx);
2309 const float inv = 1.0f / s;
2310 for (
float &p : wdl_prob)
2314 for (
int o = 0; o < 3; ++o)
2315 out_wdl[o] = wdl_prob[o];
2332 float wdl0[3] = {0.0f, 0.0f, 0.0f};
2333 float wdl1[3] = {0.0f, 0.0f, 0.0f};
2336 for (
int o = 0; o < 3; ++o)
2337 out_wdl[o] = w0 * wdl0[o] + w1 * wdl1[o];
static const char * expert_pool_mode_name(ExpertPoolMode mode)
Definition MoECacheModel.hpp:835
static constexpr int kDefaultMixerOut
Definition MoECacheModel.hpp:87
static void pool2x2_region_base(int region, int &r0, int &c0)
Definition MoECacheModel.hpp:855
static constexpr int kMaxGlobals
Definition MoECacheModel.hpp:111
static bool is_slider_dirty(const Board &board, const Board &old_board, Color side, PieceType pt)
Checks if a slider piece type (Bishop, Rook, Queen) has become "dirty".
Definition MoECacheModel.hpp:191
static constexpr int NET_EXPERT_HIDDEN
Definition MoECacheModel.hpp:100
static int pool2x2_region_from_sq(int sq)
Definition MoECacheModel.hpp:849
static constexpr int kDefaultExpertHidden
Definition MoECacheModel.hpp:92
static double us(Clock::time_point a, Clock::time_point b)
Definition MoECacheModel.hpp:132
static constexpr int kMaxExpertHidden
Definition MoECacheModel.hpp:114
static void conv1x1_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
Computes a 1x1 convolution (pointwise convolution) across the spatial grid, followed by ReLU.
Definition MoECacheModel.hpp:685
static void simd_add_scaled(float *HOT_RESTRICT dst, const float *HOT_RESTRICT src, float scale, int n)
SIMD-optimized fused multiply-add (FMA) for adding a scaled vector to a destination vector....
Definition MoECacheModel.hpp:345
static constexpr int kMaxExperts
Definition MoECacheModel.hpp:112
ExpertPoolMode
Defines how spatial features are pooled before entering the fully-connected expert networks.
Definition MoECacheModel.hpp:826
@ Pool2x2Max
Definition MoECacheModel.hpp:830
@ Gap
Definition MoECacheModel.hpp:828
@ Flat
Definition MoECacheModel.hpp:827
@ Pool2x2Avg
Definition MoECacheModel.hpp:829
static bool is_slider_branch_rich(int branch_idx)
Determines if a specific branch corresponds to a slider piece type (Bishop, Rook, Queen).
Definition MoECacheModel.hpp:146
static void conv3x3_relu_bd16_dispatch(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic)
Definition MoECacheModel.hpp:501
static constexpr int kInputBypassPlanes
Definition MoECacheModel.hpp:103
static constexpr int NET_GLOBALS
Definition MoECacheModel.hpp:97
static constexpr int kMaxBypass
Definition MoECacheModel.hpp:110
static T * assume_aligned_32(T *ptr)
Hints the compiler that a pointer is aligned to a 32-byte boundary (for AVX operations).
Definition MoECacheModel.hpp:357
static constexpr int kMaxMixerOut
Definition MoECacheModel.hpp:109
static constexpr int kDefaultGlobals
Definition MoECacheModel.hpp:89
static void conv3x3_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
Computes a full 3x3 convolution from ic input channels to oc output channels, followed by ReLU.
Definition MoECacheModel.hpp:485
static constexpr int kDefaultBranchDim
Definition MoECacheModel.hpp:86
#define FORCE_VECTORIZE
Definition MoECacheModel.hpp:123
static void conv3x3_single_out_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, float b, float *HOT_RESTRICT out, int ic)
Computes a 3x3 convolution across multiple input channels (ic) to produce a single output channel,...
Definition MoECacheModel.hpp:457
static std::array< uint8_t, 12 > build_dirty_mask(const Board &old_board, const Board &new_board, const Move &mv)
Analyzes a chess move to determine exactly which of the 12 spatial branches need recomputing.
Definition MoECacheModel.hpp:257
static constexpr int kPool2x2Regions
Definition MoECacheModel.hpp:833
static void depthwise_conv3x3_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int channels)
Computes a depthwise 3x3 convolution followed by a ReLU activation.
Definition MoECacheModel.hpp:649
static constexpr int NET_MIXER_OUT
Definition MoECacheModel.hpp:95
static void conv3x3_accumulate_plane(float *HOT_RESTRICT out_plane, const float *HOT_RESTRICT in_plane, const float *HOT_RESTRICT wk)
Computes a full 3x3 2D convolution for a single input channel and accumulates it into out_plane.
Definition MoECacheModel.hpp:385
static constexpr int kMaxBranchDim
Definition MoECacheModel.hpp:108
static constexpr int kInputGlobals
Definition MoECacheModel.hpp:104
static constexpr int kDefaultExpertBottleneck
Definition MoECacheModel.hpp:91
#define HOT_RESTRICT
Definition MoECacheModel.hpp:129
static constexpr int kMaxExpertBottleneck
Definition MoECacheModel.hpp:113
static constexpr int kMixerOcTile
Definition MoECacheModel.hpp:115
std::chrono::high_resolution_clock Clock
Definition MoECacheModel.hpp:84
static constexpr int kDefaultBypass
Definition MoECacheModel.hpp:88
static constexpr int NET_BRANCH_DIM
Definition MoECacheModel.hpp:94
static constexpr int NET_BYPASS
Definition MoECacheModel.hpp:96
static int active_channels_for_branch(int branch_idx)
Returns the number of active input feature channels for a given branch.
Definition MoECacheModel.hpp:167
static void mark_head(std::array< uint8_t, 12 > &m, Color s, PieceType pt)
Marks a specific piece type for a specific color as "dirty" in the dirty mask.
Definition MoECacheModel.hpp:223
static constexpr int NET_EXPERTS
Definition MoECacheModel.hpp:98
static constexpr int kDefaultExperts
Definition MoECacheModel.hpp:90
static constexpr int NET_EXPERT_BOTTLENECK
Definition MoECacheModel.hpp:99
static bool branch_planes_changed(const FactorizedInput &a, const FactorizedInput &b, int branch_idx)
Performs a fast byte-level comparison to see if two spatial branch inputs are identical.
Definition MoECacheModel.hpp:328
chess::Move Move
Alias for chess::Move.
Definition Types.h:15
chess::PieceType PieceType
Alias for chess::PieceType.
Definition Types.h:20
chess::Color Color
Alias for chess::Color.
Definition Types.h:17
chess::Piece Piece
Alias for chess::Piece.
Definition Types.h:19
chess::Bitboard Bitboard
Alias for chess::Bitboard.
Definition Types.h:21
chess::Square Square
Alias for chess::Square.
Definition Types.h:18
chess::Board Board
Alias for chess::Board.
Definition Types.h:14
Definition MoECacheModel.hpp:860
int nGlobals
Definition MoECacheModel.hpp:875
unsigned seed
Definition MoECacheModel.hpp:863
int branchConvLayers
Definition MoECacheModel.hpp:871
int minParallelActiveExperts
Definition MoECacheModel.hpp:866
int minParallelDirtyHeads
Definition MoECacheModel.hpp:865
int nThreads
Definition MoECacheModel.hpp:864
int nBypass
Definition MoECacheModel.hpp:874
int expertHidden
Definition MoECacheModel.hpp:878
int nExperts
Definition MoECacheModel.hpp:876
int nPlies
Definition MoECacheModel.hpp:862
int denseDirtySqThreshold
Definition MoECacheModel.hpp:867
ExpertPoolMode expertPoolMode
Definition MoECacheModel.hpp:868
int expertBottleneck
Definition MoECacheModel.hpp:877
int mixerOut
Definition MoECacheModel.hpp:873
int branchDim
Definition MoECacheModel.hpp:872
bool routeSlowGlobals
Definition MoECacheModel.hpp:869
int nGames
Definition MoECacheModel.hpp:861
Definition MoECacheModel.hpp:881
double incRouteUs
Definition MoECacheModel.hpp:899
double incBypassDeltaUs
Definition MoECacheModel.hpp:911
double incBranchDeltaUs
Definition MoECacheModel.hpp:910
double incExpertCacheRebuildUs
Definition MoECacheModel.hpp:915
double incExpertBottleneckUs
Definition MoECacheModel.hpp:913
double fullUs
Definition MoECacheModel.hpp:883
double incUs
Definition MoECacheModel.hpp:885
long long incPlies
Definition MoECacheModel.hpp:886
long long incEvents
Definition MoECacheModel.hpp:888
double prepFeatureUs
Definition MoECacheModel.hpp:892
double fullUpdateUs
Definition MoECacheModel.hpp:897
double incExpertUs
Definition MoECacheModel.hpp:901
long long fullPlies
Definition MoECacheModel.hpp:884
double incGlobalReluUs
Definition MoECacheModel.hpp:912
double incHiddenDeltaUs
Definition MoECacheModel.hpp:914
double avgDirtyHeads() const
Definition MoECacheModel.hpp:921
std::string name
Definition MoECacheModel.hpp:882
double fullExpertCacheUs
Definition MoECacheModel.hpp:907
double fullNps() const
Definition MoECacheModel.hpp:917
double fullMixerAccumUs
Definition MoECacheModel.hpp:905
double fullRouteUs
Definition MoECacheModel.hpp:896
long long incDirtyHeads
Definition MoECacheModel.hpp:887
double prepPlayUs
Definition MoECacheModel.hpp:891
double fullBranchForwardUs
Definition MoECacheModel.hpp:904
double incUpdateUs
Definition MoECacheModel.hpp:900
double fullExpertUs
Definition MoECacheModel.hpp:898
double fullMixerReluUs
Definition MoECacheModel.hpp:906
double incNps() const
Definition MoECacheModel.hpp:920
double prepDirtyUs
Definition MoECacheModel.hpp:893
Definition MoECacheModel.hpp:926
std::array< float, NET_BRANCH_DIM > b
Definition MoECacheModel.hpp:930
int oc
Definition MoECacheModel.hpp:928
int ic
Definition MoECacheModel.hpp:927
std::array< float,(size_t) 5 *NET_BRANCH_DIM *9 > w
Definition MoECacheModel.hpp:929
Definition MoECacheModel.hpp:933
BranchLayer l1
Definition MoECacheModel.hpp:935
BranchLayer l2
Definition MoECacheModel.hpp:936
BranchLayer l0
Definition MoECacheModel.hpp:934
Definition MoECacheModel.hpp:939
std::array< float, 3 > bWdl
Definition MoECacheModel.hpp:970
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *NET_EXPERT_HIDDEN > wHGT
Definition MoECacheModel.hpp:950
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK > wHG
Definition MoECacheModel.hpp:946
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK *kPool2x2Regions > wH16
Definition MoECacheModel.hpp:954
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *64 *NET_EXPERT_HIDDEN > wHT
Definition MoECacheModel.hpp:964
std::array< float, NET_EXPERT_BOTTLENECK > bConv
Definition MoECacheModel.hpp:942
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *NET_MIXER_OUT > wConv
Definition MoECacheModel.hpp:941
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *kPool2x2Regions *NET_EXPERT_HIDDEN > wH16T
Definition MoECacheModel.hpp:958
std::array< float,(size_t) 3 *NET_EXPERT_HIDDEN > wWdl
Definition MoECacheModel.hpp:969
std::array< float, NET_EXPERT_HIDDEN > bH
Definition MoECacheModel.hpp:966
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK *64 > wH
Definition MoECacheModel.hpp:962
Definition MoECacheModel.hpp:1112
double fullExpertCacheUs
Definition MoECacheModel.hpp:1116
double incExpertBottleneckUs
Definition MoECacheModel.hpp:1121
double incGlobalReluUs
Definition MoECacheModel.hpp:1120
double incBranchDeltaUs
Definition MoECacheModel.hpp:1118
double incBypassDeltaUs
Definition MoECacheModel.hpp:1119
double fullBranchForwardUs
Definition MoECacheModel.hpp:1113
double incHiddenDeltaUs
Definition MoECacheModel.hpp:1122
double fullMixerAccumUs
Definition MoECacheModel.hpp:1114
double incExpertCacheRebuildUs
Definition MoECacheModel.hpp:1123
double fullMixerReluUs
Definition MoECacheModel.hpp:1115
Thread-local state for incremental, lightning-fast neural network inference.
Definition MoECacheModel.hpp:1051
static constexpr int nExperts
Definition MoECacheModel.hpp:1063
void run_top2_experts(int e0, int e1, float w0, float w1, float out_wdl[3])
Combines the output of the top 2 routed experts based on their routing weights.
Definition MoECacheModel.hpp:2331
std::unique_ptr< PersistentThreadPool > threadPool
Definition MoECacheModel.hpp:1070
static constexpr int nGlobals
Definition MoECacheModel.hpp:1062
std::array< float,(size_t) kMaxBranchDim *64 > scratchBranchDelta
Definition MoECacheModel.hpp:1102
SharedMoEWeights & mutable_owned_weights()
Definition MoECacheModel.hpp:1136
std::array< std::array< float, kMaxExpertHidden >, kMaxExperts > hiddenAcc
Definition MoECacheModel.hpp:1083
static constexpr int bd
Definition MoECacheModel.hpp:1059
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > exPreAccum
Definition MoECacheModel.hpp:1078
std::shared_ptr< SharedMoEWeights > ownedWeights
Definition MoECacheModel.hpp:1068
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > exReluCache
Definition MoECacheModel.hpp:1080
static void validate_fixed_architecture(const BenchConfig &cfg)
Definition MoECacheModel.hpp:1306
void rebuild_hidden_acc_from_pool2x2(int e, bool max_pool)
Definition MoECacheModel.hpp:1524
void branch_forward_bd16_fast(const Branch &br, const float *HOT_RESTRICT in_planes, float *HOT_RESTRICT out, float *HOT_RESTRICT mid_plane, float *HOT_RESTRICT l1_accum)
Definition MoECacheModel.hpp:1407
void fill_random(unsigned seed)
Definition MoECacheModel.hpp:1342
std::array< float,(size_t) kMaxMixerOut *64 > mixerReluCache
Definition MoECacheModel.hpp:1075
std::array< uint8_t, kMaxExperts > exValid
Definition MoECacheModel.hpp:1081
void update_incremental(const FactorizedInput &cur, const FactorizedInput &prev, const int *dirty_branches, int dirty_count, const int *active_experts, int active_count)
Performs an incremental network update by computing and applying only the differences.
Definition MoECacheModel.hpp:1772
std::array< float, kMaxMixerOut > scratchGproj
Definition MoECacheModel.hpp:1104
std::array< float,(size_t) kMaxBranchDim *64 > scratchT0
Definition MoECacheModel.hpp:1092
void rebuild_hidden_acc_from_gap(int e)
Definition MoECacheModel.hpp:1495
float global_proj_at(int oc, const float *g) const
Definition MoECacheModel.hpp:1565
std::array< float,(size_t) 12 *kMaxBranchDim *64 > scratchDirtyBranches
Definition MoECacheModel.hpp:1100
std::array< float,(size_t) kMaxMixerOut *64 > mixerLinearAccum
Definition MoECacheModel.hpp:1074
int minParallelDirtyHeads
Definition MoECacheModel.hpp:1054
std::array< float,(size_t) 12 *kMaxBranchDim *64 > branchCache
Definition MoECacheModel.hpp:1073
long long total_weights() const
Definition MoECacheModel.hpp:1212
void branch_forward_with_scratch(int b, const float *in_planes, float *out, float *scratch0, float *scratch1)
Definition MoECacheModel.hpp:1433
int nThreads
Definition MoECacheModel.hpp:1053
std::array< std::array< float,(size_t) kMaxBranchDim *64 >, 12 > scratchParallelBranch1
Definition MoECacheModel.hpp:1098
void init(const BenchConfig &cfg)
Definition MoECacheModel.hpp:1336
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > scratchParallelExpertDelta
Definition MoECacheModel.hpp:1109
void top2_experts(const float *global, int &e0, int &e1, float &w0, float &w1) const
Definition MoECacheModel.hpp:1574
ExpertPoolMode expertPoolMode
Definition MoECacheModel.hpp:1057
std::array< float,(size_t) kMaxBypass *64 > scratchBypassDelta
Definition MoECacheModel.hpp:1103
void branch_forward(int b, const float *in_planes, float *out)
Definition MoECacheModel.hpp:1471
void parallel_for_indices(int n, int min_parallel_n, Fn &&fn)
Definition MoECacheModel.hpp:1197
void reset_runtime_state()
Definition MoECacheModel.hpp:1143
bool routeSlowGlobals
Definition MoECacheModel.hpp:1058
std::array< float, kMaxExpertHidden > scratchHidden
Definition MoECacheModel.hpp:1110
PhaseProfile profile
Definition MoECacheModel.hpp:1126
std::array< float,(size_t) kMaxExpertBottleneck *64 > scratchFlatDelta
Definition MoECacheModel.hpp:1107
long long experts_total_weights() const
Definition MoECacheModel.hpp:1275
static constexpr int nf
Definition MoECacheModel.hpp:1060
std::array< std::array< float,(size_t) kMaxBranchDim *64 >, 12 > scratchParallelBranch0
Definition MoECacheModel.hpp:1096
static constexpr int nBypass
Definition MoECacheModel.hpp:1061
std::array< float, kMaxGlobals > oldGlobalV
Definition MoECacheModel.hpp:1089
void parallel_for_indices(int n, Fn &&fn)
Definition MoECacheModel.hpp:1208
static constexpr int eh
Definition MoECacheModel.hpp:1065
const SharedMoEWeights * weights
Definition MoECacheModel.hpp:1067
void copy_weights_from(const MoEDoubleAccumulator &src)
Definition MoECacheModel.hpp:1179
void reset_profile()
Definition MoECacheModel.hpp:1194
void run_active_expert(int e, float out_wdl[3])
Computes the final hidden layer and WDL output for a single expert.
Definition MoECacheModel.hpp:2284
long long single_expert_weights() const
Definition MoECacheModel.hpp:1251
std::array< float,(size_t) kMaxMixerOut *64 > scratchDeltaRelu
Definition MoECacheModel.hpp:1105
void rebuild_hidden_acc_from_flat(int e)
Definition MoECacheModel.hpp:1476
int branchConvLayers
Definition MoECacheModel.hpp:1052
std::array< float,(size_t) kMaxBranchDim *64 > scratchNewBranch
Definition MoECacheModel.hpp:1094
static constexpr int ebo
Definition MoECacheModel.hpp:1064
int denseDirtySqThreshold
Definition MoECacheModel.hpp:1056
const SharedMoEWeights & shared_weights() const
Definition MoECacheModel.hpp:1130
int minParallelActiveExperts
Definition MoECacheModel.hpp:1055
std::array< float,(size_t) kMaxBranchDim *64 > scratchT1
Definition MoECacheModel.hpp:1093
void rebuild_expert_cache_from_mixer(int e)
Definition MoECacheModel.hpp:1618
void init(const SharedMoEWeights *shared, const BenchConfig &cfg)
Definition MoECacheModel.hpp:1315
std::array< std::array< float, kMaxExpertBottleneck >, kMaxExperts > exGapCache
Definition MoECacheModel.hpp:1085
bool initialized
Definition MoECacheModel.hpp:1128
std::array< std::array< float,(size_t) kMaxExpertBottleneck *kPool2x2Regions >, kMaxExperts > exPool16Cache
Definition MoECacheModel.hpp:1088
void full_rebuild_accumulators(const FactorizedInput &inp, const int *active_experts, int active_count)
Performs a full forward pass of the model, discarding all cache.
Definition MoECacheModel.hpp:1665
long long backbone_weights() const
Definition MoECacheModel.hpp:1297
long long runtime_topk_weights(int topk) const
Definition MoECacheModel.hpp:1301
Contains the globally shared, read-only weights for the Factorized MoE network.
Definition MoECacheModel.hpp:986
std::array< float,(size_t) NET_EXPERTS *NET_GLOBALS > gateW
Definition MoECacheModel.hpp:1008
std::array< float, NET_MIXER_OUT > globalB
Definition MoECacheModel.hpp:1005
std::array< float,(size_t) NET_MIXER_OUT *NET_GLOBALS > globalW
Definition MoECacheModel.hpp:1004
std::array< float, NET_MIXER_OUT > mixerB
Definition MoECacheModel.hpp:1001
std::array< float,(size_t) 12 *NET_MIXER_OUT *NET_BRANCH_DIM > mixerWBr
Definition MoECacheModel.hpp:998
static constexpr int bd
Definition MoECacheModel.hpp:987
static constexpr int nf
Definition MoECacheModel.hpp:988
static constexpr int nExperts
Definition MoECacheModel.hpp:991
std::array< float,(size_t) NET_BYPASS *NET_MIXER_OUT > mixerWBp
Definition MoECacheModel.hpp:1000
void init_architecture(int branchConvLayers)
Definition MoECacheModel.hpp:1013
static constexpr int ebo
Definition MoECacheModel.hpp:992
static constexpr int eh
Definition MoECacheModel.hpp:993
std::array< Expert, NET_EXPERTS > experts
Definition MoECacheModel.hpp:1011
std::array< float, NET_EXPERTS > gateB
Definition MoECacheModel.hpp:1009
static constexpr int nGlobals
Definition MoECacheModel.hpp:990
std::array< Branch, 12 > branches
Definition MoECacheModel.hpp:995
static constexpr int nBypass
Definition MoECacheModel.hpp:989