io-chess
UCI chess engine
Loading...
Searching...
No Matches
MoECacheModel.hpp
Go to the documentation of this file.
1
40#pragma once
41
42#include <algorithm>
43#include <array>
44#include <atomic>
45#include <cassert>
46#include <chrono>
47#include <cmath>
48#include <condition_variable>
49#include <cstdint>
50#include <cstdlib>
51#include <cstring>
52#include <functional>
53#include <iomanip>
54#include <iostream>
55#include <limits>
56#include <memory>
57#include <mutex>
58#include <queue>
59#include <random>
60#include <string>
61#include <thread>
62#include <type_traits>
63#include <utility>
64#include <vector>
65
66#if defined(__SSE__)
67#include <xmmintrin.h>
68#endif
69#if defined(__SSE3__)
70#include <pmmintrin.h>
71#endif
72#if defined(__AVX2__)
73#include <immintrin.h>
74#endif
75#if defined(__ARM_NEON)
76#include <arm_neon.h>
77#endif
78
79#include <chess.hpp>
80
82
83using namespace chess;
84using Clock = std::chrono::high_resolution_clock;
85
// Native (compiled-in) network dimensions. Weight files whose dimensions differ
// from these are rejected at load time.
static constexpr int NET_BRANCH_DIM = 16;
static constexpr int NET_MIXER_OUT = 64;
static constexpr int NET_BYPASS = 12;
static constexpr int NET_GLOBALS = 21;
static constexpr int NET_EXPERTS = 4;
static constexpr int NET_EXPERT_BOTTLENECK = 32;
static constexpr int NET_EXPERT_HIDDEN = 128;

// Configuration defaults. They are defined in terms of the native dimensions so
// the two sets of constants cannot drift apart.
static constexpr int kDefaultBranchDim = NET_BRANCH_DIM;
static constexpr int kDefaultMixerOut = NET_MIXER_OUT;
static constexpr int kDefaultBypass = NET_BYPASS;
static constexpr int kDefaultGlobals = NET_GLOBALS;
static constexpr int kDefaultExperts = NET_EXPERTS;
static constexpr int kDefaultExpertBottleneck = NET_EXPERT_BOTTLENECK;
static constexpr int kDefaultExpertHidden = NET_EXPERT_HIDDEN;

// FactorizedInput currently exposes 12 bypass planes and 32 global scalars.
static constexpr int kInputBypassPlanes = 12;
static constexpr int kInputGlobals = 32;

// Upper bounds for runtime-configurable dimensions. Keeping them compile-time
// constants lets the hot path use static, aligned storage.
static constexpr int kMaxBranchDim = NET_BRANCH_DIM;
static constexpr int kMaxMixerOut = NET_MIXER_OUT;
static constexpr int kMaxBypass = NET_BYPASS;
static constexpr int kMaxGlobals = NET_GLOBALS;
static constexpr int kMaxExperts = NET_EXPERTS;
// Sizes the per-expert bottleneck caches and scratch buffers.
static constexpr int kMaxExpertBottleneck = NET_EXPERT_BOTTLENECK;
static constexpr int kMaxExpertHidden = NET_EXPERT_HIDDEN;
static constexpr int kMixerOcTile = 8;
116
// HOT_RESTRICT: toolchain spelling of the non-standard `restrict` qualifier,
// used to promise the optimizer that pointer parameters do not alias.
#ifdef _MSC_VER
#define HOT_RESTRICT __restrict
#else
#define HOT_RESTRICT __restrict__
#endif

// FORCE_VECTORIZE: loop-vectorization hint placed directly before a `for`
// statement. It expands to a compiler-specific pragma, or to nothing on
// compilers without a known hint.
#if defined(__clang__)
#define FORCE_VECTORIZE                                                        \
  _Pragma("clang loop vectorize(enable) interleave(enable)")
#elif defined(__GNUC__)
// GCC defines __GNUG__ only together with __GNUC__, so this single test
// covers both of the C and C++ drivers.
#define FORCE_VECTORIZE _Pragma("GCC ivdep")
#else
#define FORCE_VECTORIZE
#endif
131
132static inline double us(Clock::time_point a, Clock::time_point b) {
133 return (double)std::chrono::duration_cast<std::chrono::nanoseconds>(b - a)
134 .count() /
135 1000.0;
136}
137
139
/// True for the branch indices 2, 3, 4 and 8, 9, 10, and false for every other
/// index, including out-of-range ones. This assumes branches are ordered per
/// colour as P, N, B, R, Q, K, matching mark_head's head layout, so these are
/// the bishop, rook and queen branches.
static inline bool is_slider_branch_rich(int branch_idx) {
  if (branch_idx < 0 || branch_idx >= 12)
    return false;
  const int within_side = branch_idx % 6;
  return within_side >= 2 && within_side <= 4;
}
159
167static inline int active_channels_for_branch(int branch_idx) {
168 // Rich feature policy:
169 // - Sliders (B/R/Q): 5 planes (includes X-ray)
170 // - Knights: 4 planes
171 // - Pawns/Kings: 4 planes
172 if (is_slider_branch_rich(branch_idx))
173 return 5;
174 return 4;
175}
176
191static bool is_slider_dirty(const Board &board, const Board &old_board,
192 Color side, PieceType pt) {
193 Bitboard current_sliders = board.pieces(pt, side);
194 Bitboard old_sliders = old_board.pieces(pt, side);
195 if (current_sliders != old_sliders)
196 return true;
197
198 Bitboard changed_occupancy = board.occ() ^ old_board.occ();
199 if (!changed_occupancy)
200 return false;
201
202 Bitboard sliders_copy = current_sliders;
203 while (sliders_copy) {
204 Square sq = Square(sliders_copy.pop());
205 Bitboard rays = 0;
206 if (pt == PieceType::BISHOP || pt == PieceType::QUEEN)
207 rays |= attacks::bishop(sq, board.occ());
208 if (pt == PieceType::ROOK || pt == PieceType::QUEEN)
209 rays |= attacks::rook(sq, board.occ());
210 if (rays & changed_occupancy)
211 return true;
212 }
213 return false;
214}
215
223static void mark_head(std::array<uint8_t, 12> &m, Color s, PieceType pt) {
224 int b = (s == Color::WHITE) ? 0 : 6;
225 int i = -1;
226 if (pt == PieceType::PAWN)
227 i = b;
228 else if (pt == PieceType::KNIGHT)
229 i = b + 1;
230 else if (pt == PieceType::BISHOP)
231 i = b + 2;
232 else if (pt == PieceType::ROOK)
233 i = b + 3;
234 else if (pt == PieceType::QUEEN)
235 i = b + 4;
236 else if (pt == PieceType::KING)
237 i = b + 5;
238 if (i >= 0)
239 m[(size_t)i] = 1;
240}
241
257static std::array<uint8_t, 12> build_dirty_mask(const Board &old_board,
258 const Board &new_board,
259 const Move &mv) {
260 std::array<uint8_t, 12> m{};
261 Color mover = old_board.sideToMove();
262 Color enemy = ~mover;
263
264 Piece p = old_board.at(mv.from());
265 if (p == Piece::NONE) {
266 m.fill(1);
267 return m;
268 }
269
270 mark_head(m, mover, p.type());
271 if (mv.typeOf() == Move::PROMOTION)
272 mark_head(m, mover, mv.promotionType());
273 if (mv.typeOf() == Move::CASTLING)
274 mark_head(m, mover, PieceType::ROOK);
275
276 Square cs = Square::NO_SQ;
277 Piece cap = Piece::NONE;
278 if (mv.typeOf() == Move::ENPASSANT) {
279 cs = mv.to().ep_square();
280 cap = Piece(PieceType::PAWN, enemy);
281 } else if (old_board.isCapture(mv)) {
282 cs = mv.to();
283 cap = old_board.at(mv.to());
284 }
285 if (cap != Piece::NONE)
286 mark_head(m, enemy, cap.type());
287
288 auto king_touch = [](Square sq, Square k) {
289 if (sq == Square::NO_SQ || k == Square::NO_SQ)
290 return true;
291 return static_cast<bool>((attacks::king(k) | Bitboard(1ULL << k.index())) &
292 Bitboard(1ULL << sq.index()));
293 };
294
295 Square wk = old_board.kingSq(Color::WHITE);
296 Square bk = old_board.kingSq(Color::BLACK);
297 if (king_touch(mv.from(), wk) || king_touch(mv.to(), wk) ||
298 king_touch(cs, wk) || p.type() == PieceType::KING) {
299 mark_head(m, Color::WHITE, PieceType::KING);
300 }
301 if (king_touch(mv.from(), bk) || king_touch(mv.to(), bk) ||
302 king_touch(cs, bk) || p.type() == PieceType::KING) {
303 mark_head(m, Color::BLACK, PieceType::KING);
304 }
305
306 // Tight slider invalidation: mark only if slider moved/captured or ray
307 // changed.
308 for (auto side : {Color::WHITE, Color::BLACK}) {
309 if (is_slider_dirty(new_board, old_board, side, PieceType::BISHOP))
310 mark_head(m, side, PieceType::BISHOP);
311 if (is_slider_dirty(new_board, old_board, side, PieceType::ROOK))
312 mark_head(m, side, PieceType::ROOK);
313 if (is_slider_dirty(new_board, old_board, side, PieceType::QUEEN))
314 mark_head(m, side, PieceType::QUEEN);
315 }
316
317 return m;
318}
319
328static inline bool branch_planes_changed(const FactorizedInput &a,
329 const FactorizedInput &b,
330 int branch_idx) {
331 const int ch = active_channels_for_branch(branch_idx);
332 const size_t bytes = (size_t)ch * 64 * sizeof(float);
333 return std::memcmp(&a.branches[branch_idx][0][0],
334 &b.branches[branch_idx][0][0], bytes) != 0;
335}
336
345static inline void simd_add_scaled(float *HOT_RESTRICT dst,
346 const float *HOT_RESTRICT src, float scale,
347 int n) {
349 for (int i = 0; i < n; ++i)
350 dst[i] += src[i] * scale;
351}
352
/// Returns `ptr` unchanged, telling GCC/Clang that it is 32-byte aligned so
/// the optimizer may use aligned vector loads and stores.
///
/// The promise is checked in debug builds. Handing a misaligned pointer to
/// __builtin_assume_aligned is undefined behaviour, so the assert turns a
/// silent miscompile into a diagnosable failure. On other compilers the hint
/// is a no-op.
template <typename T> static inline T *assume_aligned_32(T *ptr) {
  assert((reinterpret_cast<std::uintptr_t>(ptr) & 31u) == 0u);
#if defined(__clang__) || defined(__GNUC__)
  return static_cast<T *>(__builtin_assume_aligned(ptr, 32));
#else
  return ptr;
#endif
}

/// Const-pointer overload of assume_aligned_32; same contract and debug check.
template <typename T> static inline const T *assume_aligned_32(const T *ptr) {
  assert((reinterpret_cast<std::uintptr_t>(ptr) & 31u) == 0u);
#if defined(__clang__) || defined(__GNUC__)
  return static_cast<const T *>(__builtin_assume_aligned(ptr, 32));
#else
  return ptr;
#endif
}
372
385static inline void conv3x3_accumulate_plane(float *HOT_RESTRICT out_plane,
386 const float *HOT_RESTRICT in_plane,
387 const float *HOT_RESTRICT wk) {
388 // Center (64 contiguous cells)
389 const float wC = wk[4];
391 for (int sq = 0; sq < 64; ++sq)
392 out_plane[sq] += in_plane[sq] * wC;
393
394 // North/South as contiguous 1-D slices to avoid short-tail paths.
395 const float wN = wk[1];
397 for (int sq = 8; sq < 64; ++sq)
398 out_plane[sq] += in_plane[sq - 8] * wN;
399
400 const float wS = wk[7];
402 for (int sq = 0; sq < 56; ++sq)
403 out_plane[sq] += in_plane[sq + 8] * wS;
404
405 const float wNW = wk[0];
406 for (int r = 1; r < 8; ++r) {
408 for (int c = 1; c < 8; ++c)
409 out_plane[r * 8 + c] += in_plane[(r - 1) * 8 + (c - 1)] * wNW;
410 }
411
412 const float wNE = wk[2];
413 for (int r = 1; r < 8; ++r) {
415 for (int c = 0; c < 7; ++c)
416 out_plane[r * 8 + c] += in_plane[(r - 1) * 8 + (c + 1)] * wNE;
417 }
418
419 const float wW = wk[3];
420 for (int r = 0; r < 8; ++r) {
422 for (int c = 1; c < 8; ++c)
423 out_plane[r * 8 + c] += in_plane[r * 8 + (c - 1)] * wW;
424 }
425
426 const float wE = wk[5];
427 for (int r = 0; r < 8; ++r) {
429 for (int c = 0; c < 7; ++c)
430 out_plane[r * 8 + c] += in_plane[r * 8 + (c + 1)] * wE;
431 }
432
433 const float wSW = wk[6];
434 for (int r = 0; r < 7; ++r) {
436 for (int c = 1; c < 8; ++c)
437 out_plane[r * 8 + c] += in_plane[(r + 1) * 8 + (c - 1)] * wSW;
438 }
439
440 const float wSE = wk[8];
441 for (int r = 0; r < 7; ++r) {
443 for (int c = 0; c < 7; ++c)
444 out_plane[r * 8 + c] += in_plane[(r + 1) * 8 + (c + 1)] * wSE;
445 }
446}
447
457static inline void conv3x3_single_out_relu(const float *HOT_RESTRICT in,
458 const float *HOT_RESTRICT w, float b,
459 float *HOT_RESTRICT out, int ic) {
461 for (int sq = 0; sq < 64; ++sq)
462 out[sq] = b;
463
464 for (int i = 0; i < ic; ++i) {
465 const float *wk = &w[(size_t)i * 9];
466 const float *in_plane = &in[(size_t)i * 64];
467 conv3x3_accumulate_plane(out, in_plane, wk);
468 }
469
471 for (int sq = 0; sq < 64; ++sq)
472 out[sq] = std::max(0.0f, out[sq]);
473}
474
485static void conv3x3_relu(const float *HOT_RESTRICT in,
486 const float *HOT_RESTRICT w,
487 const float *HOT_RESTRICT b, float *HOT_RESTRICT out,
488 int ic, int oc) {
489 const float *aligned_in = assume_aligned_32(in);
490 const float *aligned_w = assume_aligned_32(w);
491 const float *aligned_b = assume_aligned_32(b);
492 float *aligned_out = assume_aligned_32(out);
493
494 for (int o = 0; o < oc; ++o) {
495 float *out_plane = &aligned_out[(size_t)o * 64];
496 conv3x3_single_out_relu(aligned_in, &aligned_w[(size_t)o * ic * 9],
497 aligned_b[o], out_plane, ic);
498 }
499}
500
501static inline void conv3x3_relu_bd16_dispatch(const float *HOT_RESTRICT in,
502 const float *HOT_RESTRICT w,
503 const float *HOT_RESTRICT b,
504 float *HOT_RESTRICT out, int ic) {
505 conv3x3_relu(in, w, b, out, ic, 16);
506}
507
508#if defined(__ARM_NEON)
509static inline float32x4_t neon_fma_n(float32x4_t acc, float32x4_t x, float w) {
510#if defined(__aarch64__)
511 return vfmaq_n_f32(acc, x, w);
512#else
513 return vmlaq_n_f32(acc, x, w);
514#endif
515}
516
517static inline void conv3x3_row_accumulate_4oc_l0(
518 const float *w0, const float *w1, const float *w2, const float *w3,
519 const float *in_row, const float32x4_t &vzero, float32x4_t &a00,
520 float32x4_t &a01, float32x4_t &a10, float32x4_t &a11, float32x4_t &a20,
521 float32x4_t &a21, float32x4_t &a30, float32x4_t &a31, int kr) {
522 const float32x4_t v_row_0 = vld1q_f32(in_row);
523 const float32x4_t v_row_1 = vld1q_f32(in_row + 4);
524 const float32x4_t v_l0 = vextq_f32(vzero, v_row_0, 3);
525 const float32x4_t v_l1 = vextq_f32(v_row_0, v_row_1, 3);
526 const float32x4_t v_r0 = vextq_f32(v_row_0, v_row_1, 1);
527 const float32x4_t v_r1 = vextq_f32(v_row_1, vzero, 1);
528
529 const int base = kr * 3;
530 const float wl00 = w0[base + 0], wc00 = w0[base + 1], wr00 = w0[base + 2];
531 const float wl10 = w1[base + 0], wc10 = w1[base + 1], wr10 = w1[base + 2];
532 const float wl20 = w2[base + 0], wc20 = w2[base + 1], wr20 = w2[base + 2];
533 const float wl30 = w3[base + 0], wc30 = w3[base + 1], wr30 = w3[base + 2];
534
535 a00 = neon_fma_n(a00, v_l0, wl00);
536 a01 = neon_fma_n(a01, v_l1, wl00);
537 a00 = neon_fma_n(a00, v_row_0, wc00);
538 a01 = neon_fma_n(a01, v_row_1, wc00);
539 a00 = neon_fma_n(a00, v_r0, wr00);
540 a01 = neon_fma_n(a01, v_r1, wr00);
541
542 a10 = neon_fma_n(a10, v_l0, wl10);
543 a11 = neon_fma_n(a11, v_l1, wl10);
544 a10 = neon_fma_n(a10, v_row_0, wc10);
545 a11 = neon_fma_n(a11, v_row_1, wc10);
546 a10 = neon_fma_n(a10, v_r0, wr10);
547 a11 = neon_fma_n(a11, v_r1, wr10);
548
549 a20 = neon_fma_n(a20, v_l0, wl20);
550 a21 = neon_fma_n(a21, v_l1, wl20);
551 a20 = neon_fma_n(a20, v_row_0, wc20);
552 a21 = neon_fma_n(a21, v_row_1, wc20);
553 a20 = neon_fma_n(a20, v_r0, wr20);
554 a21 = neon_fma_n(a21, v_r1, wr20);
555
556 a30 = neon_fma_n(a30, v_l0, wl30);
557 a31 = neon_fma_n(a31, v_l1, wl30);
558 a30 = neon_fma_n(a30, v_row_0, wc30);
559 a31 = neon_fma_n(a31, v_row_1, wc30);
560 a30 = neon_fma_n(a30, v_r0, wr30);
561 a31 = neon_fma_n(a31, v_r1, wr30);
562}
563
564static inline void conv3x3_relu_bd16_neon(const float *in, const float *w,
565 const float *b, float *out,
566 int ic_total) {
567 const float32x4_t vzero = vdupq_n_f32(0.0f);
568
569 for (int oc = 0; oc < 16; oc += 4) {
570 for (int r = 0; r < 8; ++r) {
571 float32x4_t a00 = vdupq_n_f32(b[oc + 0]);
572 float32x4_t a01 = vdupq_n_f32(b[oc + 0]);
573 float32x4_t a10 = vdupq_n_f32(b[oc + 1]);
574 float32x4_t a11 = vdupq_n_f32(b[oc + 1]);
575 float32x4_t a20 = vdupq_n_f32(b[oc + 2]);
576 float32x4_t a21 = vdupq_n_f32(b[oc + 2]);
577 float32x4_t a30 = vdupq_n_f32(b[oc + 3]);
578 float32x4_t a31 = vdupq_n_f32(b[oc + 3]);
579
580 for (int ic = 0; ic < ic_total; ++ic) {
581 const float *in_plane = &in[(size_t)ic * 64];
582 const float *w0 = &w[((size_t)(oc + 0) * ic_total + (size_t)ic) * 9];
583 const float *w1 = &w[((size_t)(oc + 1) * ic_total + (size_t)ic) * 9];
584 const float *w2 = &w[((size_t)(oc + 2) * ic_total + (size_t)ic) * 9];
585 const float *w3 = &w[((size_t)(oc + 3) * ic_total + (size_t)ic) * 9];
586
587 if (r == 0) {
588 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[0], vzero,
589 a00, a01, a10, a11, a20, a21, a30, a31,
590 1);
591 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[8], vzero,
592 a00, a01, a10, a11, a20, a21, a30, a31,
593 2);
594 } else if (r == 7) {
595 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[48], vzero,
596 a00, a01, a10, a11, a20, a21, a30, a31,
597 0);
598 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3, &in_plane[56], vzero,
599 a00, a01, a10, a11, a20, a21, a30, a31,
600 1);
601 } else {
602 conv3x3_row_accumulate_4oc_l0(
603 w0, w1, w2, w3, &in_plane[(size_t)(r - 1) * 8], vzero, a00, a01,
604 a10, a11, a20, a21, a30, a31, 0);
605 conv3x3_row_accumulate_4oc_l0(w0, w1, w2, w3,
606 &in_plane[(size_t)r * 8], vzero, a00,
607 a01, a10, a11, a20, a21, a30, a31, 1);
608 conv3x3_row_accumulate_4oc_l0(
609 w0, w1, w2, w3, &in_plane[(size_t)(r + 1) * 8], vzero, a00, a01,
610 a10, a11, a20, a21, a30, a31, 2);
611 }
612 }
613
614 const size_t row_base = (size_t)r * 8;
615 vst1q_f32(&out[(size_t)(oc + 0) * 64 + row_base + 0],
616 vmaxq_f32(vzero, a00));
617 vst1q_f32(&out[(size_t)(oc + 0) * 64 + row_base + 4],
618 vmaxq_f32(vzero, a01));
619 vst1q_f32(&out[(size_t)(oc + 1) * 64 + row_base + 0],
620 vmaxq_f32(vzero, a10));
621 vst1q_f32(&out[(size_t)(oc + 1) * 64 + row_base + 4],
622 vmaxq_f32(vzero, a11));
623 vst1q_f32(&out[(size_t)(oc + 2) * 64 + row_base + 0],
624 vmaxq_f32(vzero, a20));
625 vst1q_f32(&out[(size_t)(oc + 2) * 64 + row_base + 4],
626 vmaxq_f32(vzero, a21));
627 vst1q_f32(&out[(size_t)(oc + 3) * 64 + row_base + 0],
628 vmaxq_f32(vzero, a30));
629 vst1q_f32(&out[(size_t)(oc + 3) * 64 + row_base + 4],
630 vmaxq_f32(vzero, a31));
631 }
632 }
633}
634#endif
635
649static inline void depthwise_conv3x3_relu(const float *HOT_RESTRICT in,
650 const float *HOT_RESTRICT w,
651 const float *HOT_RESTRICT b,
652 float *HOT_RESTRICT out,
653 int channels) {
654 for (int c = 0; c < channels; ++c) {
655 const float *in_plane = &in[(size_t)c * 64];
656 float *out_plane = &out[(size_t)c * 64];
657 const float *wk = &w[(size_t)c * 9];
658
660 for (int sq = 0; sq < 64; ++sq)
661 out_plane[sq] = b[c];
662
663 conv3x3_accumulate_plane(out_plane, in_plane, wk);
664
666 for (int sq = 0; sq < 64; ++sq)
667 out_plane[sq] = std::max(0.0f, out_plane[sq]);
668 }
669}
670
685static void conv1x1_relu(const float *HOT_RESTRICT in,
686 const float *HOT_RESTRICT w,
687 const float *HOT_RESTRICT b, float *HOT_RESTRICT out,
688 int ic, int oc) {
689 const float *aligned_in = assume_aligned_32(in);
690 const float *aligned_w = assume_aligned_32(w);
691 const float *aligned_b = assume_aligned_32(b);
692 float *aligned_out = assume_aligned_32(out);
693
694 // 1) Initialize output with bias.
695 for (int o = 0; o < oc; ++o) {
696 const float bias = aligned_b[o];
697 float *out_plane = &aligned_out[(size_t)o * 64];
699 for (int sq = 0; sq < 64; ++sq)
700 out_plane[sq] = bias;
701 }
702
703 // 2) Accumulate 1x1 products.
704 for (int o = 0; o < oc; ++o) {
705 float *out_plane = &aligned_out[(size_t)o * 64];
706 for (int i = 0; i < ic; ++i) {
707 const float weight = aligned_w[(size_t)o * ic + i];
708 const float *in_plane = &aligned_in[(size_t)i * 64];
710 for (int sq = 0; sq < 64; ++sq)
711 out_plane[sq] += in_plane[sq] * weight;
712 }
713 }
714
715 // 3) Apply ReLU.
716 for (int o = 0; o < oc; ++o) {
717 float *out_plane = &aligned_out[(size_t)o * 64];
719 for (int sq = 0; sq < 64; ++sq)
720 out_plane[sq] = std::max(0.0f, out_plane[sq]);
721 }
722}
723
724#if defined(__ARM_NEON)
725static inline void conv1x1_bd16_relu_neon(const float *HOT_RESTRICT in,
726 const float *HOT_RESTRICT w,
727 const float *HOT_RESTRICT b,
728 float *HOT_RESTRICT out) {
729 const float32x4_t vzero = vdupq_n_f32(0.0f);
730
731 for (int oc = 0; oc < 16; oc += 4) {
732 const float *w0 = &w[(size_t)(oc + 0) * 16];
733 const float *w1 = &w[(size_t)(oc + 1) * 16];
734 const float *w2 = &w[(size_t)(oc + 2) * 16];
735 const float *w3 = &w[(size_t)(oc + 3) * 16];
736
737 for (int sq = 0; sq < 64; sq += 4) {
738 float32x4_t acc0 = vdupq_n_f32(b[oc + 0]);
739 float32x4_t acc1 = vdupq_n_f32(b[oc + 1]);
740 float32x4_t acc2 = vdupq_n_f32(b[oc + 2]);
741 float32x4_t acc3 = vdupq_n_f32(b[oc + 3]);
742
743 for (int ic = 0; ic < 16; ++ic) {
744 const float32x4_t vin = vld1q_f32(&in[(size_t)ic * 64 + (size_t)sq]);
745#if defined(__aarch64__)
746 acc0 = vfmaq_n_f32(acc0, vin, w0[ic]);
747 acc1 = vfmaq_n_f32(acc1, vin, w1[ic]);
748 acc2 = vfmaq_n_f32(acc2, vin, w2[ic]);
749 acc3 = vfmaq_n_f32(acc3, vin, w3[ic]);
750#else
751 acc0 = vmlaq_n_f32(acc0, vin, w0[ic]);
752 acc1 = vmlaq_n_f32(acc1, vin, w1[ic]);
753 acc2 = vmlaq_n_f32(acc2, vin, w2[ic]);
754 acc3 = vmlaq_n_f32(acc3, vin, w3[ic]);
755#endif
756 }
757
758 vst1q_f32(&out[(size_t)(oc + 0) * 64 + (size_t)sq],
759 vmaxq_f32(vzero, acc0));
760 vst1q_f32(&out[(size_t)(oc + 1) * 64 + (size_t)sq],
761 vmaxq_f32(vzero, acc1));
762 vst1q_f32(&out[(size_t)(oc + 2) * 64 + (size_t)sq],
763 vmaxq_f32(vzero, acc2));
764 vst1q_f32(&out[(size_t)(oc + 3) * 64 + (size_t)sq],
765 vmaxq_f32(vzero, acc3));
766 }
767 }
768}
769#endif
770
771#if defined(__AVX2__)
/// a * b + c on 8 floats. When FMA is available this is a single fused
/// operation (one rounding); otherwise it is a multiply followed by an add.
static inline __m256 avx2_fmadd_ps(__m256 a, __m256 b, __m256 c) {
#if !defined(__FMA__)
  const __m256 product = _mm256_mul_ps(a, b);
  return _mm256_add_ps(product, c);
#else
  return _mm256_fmadd_ps(a, b, c);
#endif
}
779
780static inline void conv1x1_bd16_relu_avx2(const float *HOT_RESTRICT in,
781 const float *HOT_RESTRICT w,
782 const float *HOT_RESTRICT b,
783 float *HOT_RESTRICT out) {
784 const __m256 vzero = _mm256_setzero_ps();
785 assert((reinterpret_cast<uintptr_t>(in) & 31u) == 0u);
786 assert((reinterpret_cast<uintptr_t>(out) & 31u) == 0u);
787
788 for (int oc = 0; oc < 16; oc += 4) {
789 const float *w0 = &w[(size_t)(oc + 0) * 16];
790 const float *w1 = &w[(size_t)(oc + 1) * 16];
791 const float *w2 = &w[(size_t)(oc + 2) * 16];
792 const float *w3 = &w[(size_t)(oc + 3) * 16];
793
794 for (int sq = 0; sq < 64; sq += 8) {
795 __m256 acc0 = _mm256_set1_ps(b[oc + 0]);
796 __m256 acc1 = _mm256_set1_ps(b[oc + 1]);
797 __m256 acc2 = _mm256_set1_ps(b[oc + 2]);
798 __m256 acc3 = _mm256_set1_ps(b[oc + 3]);
799
800 for (int ic = 0; ic < 16; ++ic) {
801 const __m256 vin = _mm256_load_ps(&in[(size_t)ic * 64 + (size_t)sq]);
802 acc0 = avx2_fmadd_ps(vin, _mm256_set1_ps(w0[ic]), acc0);
803 acc1 = avx2_fmadd_ps(vin, _mm256_set1_ps(w1[ic]), acc1);
804 acc2 = avx2_fmadd_ps(vin, _mm256_set1_ps(w2[ic]), acc2);
805 acc3 = avx2_fmadd_ps(vin, _mm256_set1_ps(w3[ic]), acc3);
806 }
807
808 _mm256_store_ps(&out[(size_t)(oc + 0) * 64 + (size_t)sq],
809 _mm256_max_ps(vzero, acc0));
810 _mm256_store_ps(&out[(size_t)(oc + 1) * 64 + (size_t)sq],
811 _mm256_max_ps(vzero, acc1));
812 _mm256_store_ps(&out[(size_t)(oc + 2) * 64 + (size_t)sq],
813 _mm256_max_ps(vzero, acc2));
814 _mm256_store_ps(&out[(size_t)(oc + 3) * 64 + (size_t)sq],
815 _mm256_max_ps(vzero, acc3));
816 }
817 }
818}
819#endif
820
832
// 2x2 pooling of an 8x8 plane yields a 4x4 grid of regions.
static constexpr int kPool2x2Regions = 4 * 4;
834
835static inline const char *expert_pool_mode_name(ExpertPoolMode mode) {
836 switch (mode) {
838 return "flat";
840 return "gap";
842 return "pool2avg";
844 return "pool2max";
845 }
846 return "flat";
847}
848
/// Maps a square (0..63, index = row * 8 + col) to its 2x2 pooling region
/// (0..15, index = (row / 2) * 4 + col / 2).
static inline int pool2x2_region_from_sq(int sq) {
  const int region_row = sq >> 4;      // (sq >> 3) >> 1
  const int region_col = (sq & 7) >> 1;
  return (region_row << 2) | region_col;
}
854
/// Top-left (row, col) of 2x2 pooling region `region` (0..15, 4 regions per
/// row). Results are written to `r0` and `c0`.
static inline void pool2x2_region_base(int region, int &r0, int &c0) {
  const int region_row = region >> 2;
  const int region_col = region & 3;
  r0 = 2 * region_row;
  c0 = 2 * region_col;
}
859
880
882 std::string name;
883 double fullUs = 0.0;
884 long long fullPlies = 0;
885 double incUs = 0.0;
886 long long incPlies = 0;
887 long long incDirtyHeads = 0;
888 long long incEvents = 0;
889
890 // Pre-generation costs (outside timed inference loop).
891 double prepPlayUs = 0.0;
892 double prepFeatureUs = 0.0;
893 double prepDirtyUs = 0.0;
894
895 // Timed inference breakdown.
896 double fullRouteUs = 0.0;
897 double fullUpdateUs = 0.0;
898 double fullExpertUs = 0.0;
899 double incRouteUs = 0.0;
900 double incUpdateUs = 0.0;
901 double incExpertUs = 0.0;
902
903 // Full rebuild internal breakdown.
905 double fullMixerAccumUs = 0.0;
906 double fullMixerReluUs = 0.0;
907 double fullExpertCacheUs = 0.0;
908
909 // Incremental update internal breakdown.
910 double incBranchDeltaUs = 0.0;
911 double incBypassDeltaUs = 0.0;
912 double incGlobalReluUs = 0.0;
914 double incHiddenDeltaUs = 0.0;
916
917 double fullNps() const {
918 return fullUs > 0 ? (fullPlies * 1e6 / fullUs) : 0.0;
919 }
920 double incNps() const { return incUs > 0 ? (incPlies * 1e6 / incUs) : 0.0; }
921 double avgDirtyHeads() const {
922 return incEvents > 0 ? (double)incDirtyHeads / (double)incEvents : 0.0;
923 }
924};
925
927 int ic = 0;
928 int oc = 0;
929 alignas(32) std::array<float, (size_t)5 * NET_BRANCH_DIM * 9> w{};
930 alignas(32) std::array<float, NET_BRANCH_DIM> b{};
931};
932
938
939struct Expert {
940 alignas(32) std::array<float, (size_t)NET_EXPERT_BOTTLENECK *
941 NET_MIXER_OUT> wConv{}; // [ebo][nf]
942 alignas(32) std::array<float, NET_EXPERT_BOTTLENECK> bConv{}; // [ebo]
943
944 alignas(32)
945 std::array<float, (size_t)NET_EXPERT_HIDDEN *
946 NET_EXPERT_BOTTLENECK> wHG{}; // [eh][ebo] (GAP
947 // pooled hidden path)
948 alignas(32) std::array<
949 float, (size_t)NET_EXPERT_BOTTLENECK *
950 NET_EXPERT_HIDDEN> wHGT{}; // [ebo][eh] (transposed wHG)
951
952 alignas(
953 32) std::array<float, (size_t)NET_EXPERT_HIDDEN * NET_EXPERT_BOTTLENECK *
954 kPool2x2Regions> wH16{}; // [eh][ebo*16] (2x2
955 // pooled hidden path)
956 alignas(32) std::array<
957 float, (size_t)NET_EXPERT_BOTTLENECK * kPool2x2Regions *
958 NET_EXPERT_HIDDEN> wH16T{}; // [ebo*16][eh] (transposed wH16)
959
960 alignas(
961 32) std::array<float, (size_t)NET_EXPERT_HIDDEN * NET_EXPERT_BOTTLENECK *
962 64> wH{}; // [eh][ebo*64]
963 alignas(32) std::array<float, (size_t)NET_EXPERT_BOTTLENECK * 64 *
964 NET_EXPERT_HIDDEN> wHT{}; // [ebo*64][eh]
965 // (transposed wH)
966 alignas(32) std::array<float, NET_EXPERT_HIDDEN> bH{}; // [eh]
967
968 alignas(
969 32) std::array<float, (size_t)3 * NET_EXPERT_HIDDEN> wWdl{}; // [3][eh]
970 alignas(32) std::array<float, 3> bWdl{}; // [3]
971};
972
987 static constexpr int bd = NET_BRANCH_DIM;
988 static constexpr int nf = NET_MIXER_OUT;
989 static constexpr int nBypass = NET_BYPASS;
990 static constexpr int nGlobals = NET_GLOBALS;
991 static constexpr int nExperts = NET_EXPERTS;
992 static constexpr int ebo = NET_EXPERT_BOTTLENECK;
993 static constexpr int eh = NET_EXPERT_HIDDEN;
994
995 std::array<Branch, 12> branches{};
996
997 alignas(64) std::array<float, (size_t)12 * NET_MIXER_OUT *
998 NET_BRANCH_DIM> mixerWBr{}; // [12][nf][bd]
999 alignas(64) std::array<float, (size_t)NET_BYPASS *
1000 NET_MIXER_OUT> mixerWBp{}; // [nBypass][nf]
1001 alignas(64) std::array<float, NET_MIXER_OUT> mixerB{}; // [nf]
1002
1003 alignas(64) std::array<float, (size_t)NET_MIXER_OUT *
1004 NET_GLOBALS> globalW{}; // [nf][nGlobals]
1005 alignas(64) std::array<float, NET_MIXER_OUT> globalB{}; // [nf]
1006
1007 alignas(64) std::array<
1008 float, (size_t)NET_EXPERTS * NET_GLOBALS> gateW{}; // [nExperts][nGlobals]
1009 alignas(64) std::array<float, NET_EXPERTS> gateB{}; // [nExperts]
1010
1011 std::array<Expert, NET_EXPERTS> experts{};
1012
1013 void init_architecture(int branchConvLayers) {
1014 auto init_branch = [&](int ic) {
1015 Branch br;
1016 br.l0.ic = ic;
1017 br.l0.oc = bd;
1018
1019 if (branchConvLayers >= 2) {
1020 br.l1.ic = bd;
1021 br.l1.oc = bd;
1022 }
1023
1024 if (branchConvLayers >= 3) {
1025 br.l2.ic = bd;
1026 br.l2.oc = bd;
1027 }
1028 return br;
1029 };
1030
1031 for (int b = 0; b < 12; ++b)
1032 branches[(size_t)b] = init_branch(active_channels_for_branch(b));
1033 }
1034};
1035
1053 int nThreads = 1;
1058 bool routeSlowGlobals = false;
1059 static constexpr int bd = NET_BRANCH_DIM;
1060 static constexpr int nf = NET_MIXER_OUT;
1061 static constexpr int nBypass = NET_BYPASS;
1062 static constexpr int nGlobals = NET_GLOBALS;
1063 static constexpr int nExperts = NET_EXPERTS;
1064 static constexpr int ebo = NET_EXPERT_BOTTLENECK;
1065 static constexpr int eh = NET_EXPERT_HIDDEN;
1066
1067 const SharedMoEWeights *weights = nullptr;
1068 std::shared_ptr<SharedMoEWeights> ownedWeights{};
1069
1070 std::unique_ptr<PersistentThreadPool> threadPool;
1071
1072 // Persistent caches for incremental path.
1073 alignas(64) std::array<float, (size_t)12 * kMaxBranchDim * 64> branchCache{};
1074 alignas(64) std::array<float, (size_t)kMaxMixerOut * 64> mixerLinearAccum{};
1075 alignas(64) std::array<float, (size_t)kMaxMixerOut * 64> mixerReluCache{};
1076
1077 alignas(64) std::array<std::array<float, (size_t)kMaxExpertBottleneck * 64>,
1079 alignas(64) std::array<std::array<float, (size_t)kMaxExpertBottleneck * 64>,
1081 std::array<uint8_t, kMaxExperts> exValid{};
1082 alignas(64)
1083 std::array<std::array<float, kMaxExpertHidden>, kMaxExperts> hiddenAcc{};
1084 alignas(64) std::array<std::array<float, kMaxExpertBottleneck>,
1086 alignas(64) std::array<
1087 std::array<float, (size_t)kMaxExpertBottleneck * kPool2x2Regions>,
1089 alignas(64) std::array<float, kMaxGlobals> oldGlobalV{};
1090
1091 // Scratch buffers (fixed-size, no heap allocations).
1092 alignas(64) std::array<float, (size_t)kMaxBranchDim * 64> scratchT0{};
1093 alignas(64) std::array<float, (size_t)kMaxBranchDim * 64> scratchT1{};
1094 alignas(64) std::array<float, (size_t)kMaxBranchDim * 64> scratchNewBranch{};
1095 alignas(64) std::array<std::array<float, (size_t)kMaxBranchDim * 64>,
1097 alignas(64) std::array<std::array<float, (size_t)kMaxBranchDim * 64>,
1099 alignas(64)
1100 std::array<float, (size_t)12 * kMaxBranchDim * 64> scratchDirtyBranches{};
1101 alignas(
1102 64) std::array<float, (size_t)kMaxBranchDim * 64> scratchBranchDelta{};
1103 alignas(64) std::array<float, (size_t)kMaxBypass * 64> scratchBypassDelta{};
1104 alignas(64) std::array<float, kMaxMixerOut> scratchGproj{};
1105 alignas(64) std::array<float, (size_t)kMaxMixerOut * 64> scratchDeltaRelu{};
1106 alignas(64)
1107 std::array<float, (size_t)kMaxExpertBottleneck * 64> scratchFlatDelta{};
1108 alignas(64) std::array<std::array<float, (size_t)kMaxExpertBottleneck * 64>,
1110 alignas(64) std::array<float, kMaxExpertHidden> scratchHidden{};
1111
1114 double fullMixerAccumUs = 0.0;
1115 double fullMixerReluUs = 0.0;
1116 double fullExpertCacheUs = 0.0;
1117
1118 double incBranchDeltaUs = 0.0;
1119 double incBypassDeltaUs = 0.0;
1120 double incGlobalReluUs = 0.0;
1122 double incHiddenDeltaUs = 0.0;
1124 };
1125
1127
1128 bool initialized = false;
1129
1131 if (!weights)
1132 throw std::runtime_error("MoE weights pointer is null");
1133 return *weights;
1134 }
1135
1137 if (!ownedWeights)
1138 ownedWeights = std::make_shared<SharedMoEWeights>();
1139 weights = ownedWeights.get();
1140 return *ownedWeights;
1141 }
1142
1144 std::fill(branchCache.begin(), branchCache.end(), 0.0f);
1145 std::fill(mixerLinearAccum.begin(), mixerLinearAccum.end(), 0.0f);
1146 std::fill(mixerReluCache.begin(), mixerReluCache.end(), 0.0f);
1147 for (auto &v : exPreAccum)
1148 std::fill(v.begin(), v.end(), 0.0f);
1149 for (auto &v : exReluCache)
1150 std::fill(v.begin(), v.end(), 0.0f);
1151 std::fill(exValid.begin(), exValid.end(), 0);
1152 for (auto &v : hiddenAcc)
1153 std::fill(v.begin(), v.end(), 0.0f);
1154 for (auto &v : exGapCache)
1155 std::fill(v.begin(), v.end(), 0.0f);
1156 for (auto &v : exPool16Cache)
1157 std::fill(v.begin(), v.end(), 0.0f);
1158
1159 std::fill(scratchT0.begin(), scratchT0.end(), 0.0f);
1160 std::fill(scratchT1.begin(), scratchT1.end(), 0.0f);
1161 std::fill(scratchNewBranch.begin(), scratchNewBranch.end(), 0.0f);
1162 std::fill(scratchDirtyBranches.begin(), scratchDirtyBranches.end(), 0.0f);
1163 std::fill(scratchBranchDelta.begin(), scratchBranchDelta.end(), 0.0f);
1164 std::fill(scratchBypassDelta.begin(), scratchBypassDelta.end(), 0.0f);
1165 std::fill(scratchGproj.begin(), scratchGproj.end(), 0.0f);
1166 std::fill(scratchDeltaRelu.begin(), scratchDeltaRelu.end(), 0.0f);
1167 std::fill(scratchFlatDelta.begin(), scratchFlatDelta.end(), 0.0f);
1168 std::fill(scratchHidden.begin(), scratchHidden.end(), 0.0f);
1169 std::fill(oldGlobalV.begin(), oldGlobalV.end(), 0.0f);
1170
1171 if (nThreads > 1)
1172 threadPool = std::make_unique<PersistentThreadPool>(nThreads - 1);
1173 else
1174 threadPool.reset();
1175
1176 initialized = false;
1177 }
1178
1193
1195
1196 template <typename Fn>
1197 void parallel_for_indices(int n, int min_parallel_n, Fn &&fn) {
1198 if (n <= 0)
1199 return;
1200 if (!threadPool || nThreads <= 1 || n < min_parallel_n) {
1201 for (int i = 0; i < n; ++i)
1202 fn(i);
1203 return;
1204 }
1205 threadPool->parallel_for(n, std::function<void(int)>(std::forward<Fn>(fn)));
1206 }
1207
/// parallel_for_indices with the default threshold: use the pool as soon as
/// there are at least two indices.
template <typename Fn> void parallel_for_indices(int n, Fn &&fn) {
  constexpr int kMinParallelItems = 2;
  parallel_for_indices(n, kMinParallelItems, std::forward<Fn>(fn));
}
1211
1212 long long total_weights() const {
1213 const auto &w = shared_weights();
1214 long long total = 0;
1215 auto add_vec = [&](const auto &v) { total += (long long)v.size(); };
1216
1217 for (const auto &br : w.branches) {
1218 add_vec(br.l0.w);
1219 add_vec(br.l0.b);
1220 add_vec(br.l1.w);
1221 add_vec(br.l1.b);
1222 add_vec(br.l2.w);
1223 add_vec(br.l2.b);
1224 }
1225
1226 add_vec(w.mixerWBr);
1227 add_vec(w.mixerWBp);
1228 add_vec(w.mixerB);
1229 add_vec(w.globalW);
1230 add_vec(w.globalB);
1231 add_vec(w.gateW);
1232 add_vec(w.gateB);
1233
1234 for (const auto &ex : w.experts) {
1235 add_vec(ex.wConv);
1236 add_vec(ex.bConv);
1238 add_vec(ex.wHG);
1240 add_vec(ex.wH);
1241 else
1242 add_vec(ex.wH16);
1243 add_vec(ex.bH);
1244 add_vec(ex.wWdl);
1245 add_vec(ex.bWdl);
1246 }
1247
1248 return total;
1249 }
1250
1251 long long single_expert_weights() const {
1252 const auto &w = shared_weights();
1253 if (w.experts.empty())
1254 return 0;
1255
1256 long long total = 0;
1257 const Expert &ex = w.experts.front();
1258 auto add_vec = [&](const auto &v) { total += (long long)v.size(); };
1259
1260 add_vec(ex.wConv);
1261 add_vec(ex.bConv);
1263 add_vec(ex.wHG);
1265 add_vec(ex.wH);
1266 else
1267 add_vec(ex.wH16);
1268 add_vec(ex.bH);
1269 add_vec(ex.wWdl);
1270 add_vec(ex.bWdl);
1271
1272 return total;
1273 }
1274
1275 long long experts_total_weights() const {
1276 const auto &w = shared_weights();
1277 long long total = 0;
1278 auto add_vec = [&](const auto &v) { total += (long long)v.size(); };
1279
1280 for (const auto &ex : w.experts) {
1281 add_vec(ex.wConv);
1282 add_vec(ex.bConv);
1284 add_vec(ex.wHG);
1286 add_vec(ex.wH);
1287 else
1288 add_vec(ex.wH16);
1289 add_vec(ex.bH);
1290 add_vec(ex.wWdl);
1291 add_vec(ex.bWdl);
1292 }
1293
1294 return total;
1295 }
1296
1297 long long backbone_weights() const {
1299 }
1300
1301 long long runtime_topk_weights(int topk) const {
1302 const int k = std::clamp(topk, 0, nExperts);
1303 return backbone_weights() + (long long)k * single_expert_weights();
1304 }
1305
1307 if (cfg.branchDim != bd || cfg.mixerOut != nf || cfg.nBypass != nBypass ||
1308 cfg.nGlobals != nGlobals || cfg.nExperts != nExperts ||
1309 cfg.expertBottleneck != ebo || cfg.expertHidden != eh) {
1310 throw std::runtime_error(
1311 "Weights dimensions do not match fixed native architecture");
1312 }
1313 }
1314
1315 void init(const SharedMoEWeights *shared, const BenchConfig &cfg) {
1317
1318 if (!shared)
1319 throw std::runtime_error("Null shared MoE weights");
1320
1322 nThreads = std::max(1, cfg.nThreads);
1325 denseDirtySqThreshold = std::clamp(cfg.denseDirtySqThreshold, 1, 64);
1328
1329 const bool keep_owned = ownedWeights && (shared == ownedWeights.get());
1330 if (!keep_owned)
1331 ownedWeights.reset();
1332 weights = shared;
1334 }
1335
1336 void init(const BenchConfig &cfg) {
1337 auto &w = mutable_owned_weights();
1338 w.init_architecture(cfg.branchConvLayers);
1339 init(&w, cfg);
1340 }
1341
1342 void fill_random(unsigned seed) {
1343 std::mt19937 rng(seed);
1344 std::normal_distribution<float> nd(0.0f, 0.05f);
1345
1346 auto &w = mutable_owned_weights();
1347 w.init_architecture(branchConvLayers);
1348
1349 auto fill = [&](auto &v) {
1350 for (float &x : v)
1351 x = nd(rng);
1352 };
1353
1354 for (auto &br : w.branches) {
1355 fill(br.l0.w);
1356 fill(br.l0.b);
1357 if (branchConvLayers >= 2) {
1358 fill(br.l1.w);
1359 fill(br.l1.b);
1360 }
1361 if (branchConvLayers >= 3) {
1362 fill(br.l2.w);
1363 fill(br.l2.b);
1364 }
1365 }
1366
1367 fill(w.mixerWBr);
1368 fill(w.mixerWBp);
1369 fill(w.mixerB);
1370 fill(w.globalW);
1371 fill(w.globalB);
1372 fill(w.gateW);
1373 fill(w.gateB);
1374
1375 for (auto &ex : w.experts) {
1376 fill(ex.wConv);
1377 fill(ex.bConv);
1379 fill(ex.wHG);
1380 for (int h = 0; h < eh; ++h) {
1381 for (int f = 0; f < ebo; ++f) {
1382 ex.wHGT[(size_t)f * eh + h] = ex.wHG[(size_t)h * ebo + f];
1383 }
1384 }
1385 } else if (expertPoolMode == ExpertPoolMode::Flat) {
1386 fill(ex.wH);
1387 for (int h = 0; h < eh; ++h) {
1388 for (int f = 0; f < ebo * 64; ++f) {
1389 ex.wHT[(size_t)f * eh + h] = ex.wH[(size_t)h * (ebo * 64) + f];
1390 }
1391 }
1392 } else {
1393 fill(ex.wH16);
1394 for (int h = 0; h < eh; ++h) {
1395 for (int f = 0; f < ebo * kPool2x2Regions; ++f) {
1396 ex.wH16T[(size_t)f * eh + h] =
1397 ex.wH16[(size_t)h * (ebo * kPool2x2Regions) + f];
1398 }
1399 }
1400 }
1401 fill(ex.bH);
1402 fill(ex.wWdl);
1403 fill(ex.bWdl);
1404 }
1405 }
1406
1408 const float *HOT_RESTRICT in_planes,
1409 float *HOT_RESTRICT out,
1410 float *HOT_RESTRICT mid_plane,
1411 float *HOT_RESTRICT l1_accum) {
1412#if defined(__ARM_NEON)
1413 conv3x3_relu_bd16_neon(in_planes, br.l0.w.data(), br.l0.b.data(), mid_plane,
1414 br.l0.ic);
1415 depthwise_conv3x3_relu(mid_plane, br.l1.w.data(), br.l1.b.data(), l1_accum,
1416 16);
1417 conv1x1_bd16_relu_neon(l1_accum, br.l2.w.data(), br.l2.b.data(), out);
1418#elif defined(__AVX2__)
1419 conv3x3_relu_bd16_dispatch(in_planes, br.l0.w.data(), br.l0.b.data(),
1420 mid_plane, br.l0.ic);
1421 depthwise_conv3x3_relu(mid_plane, br.l1.w.data(), br.l1.b.data(), l1_accum,
1422 16);
1423 conv1x1_bd16_relu_avx2(l1_accum, br.l2.w.data(), br.l2.b.data(), out);
1424#else
1425 conv3x3_relu_bd16_dispatch(in_planes, br.l0.w.data(), br.l0.b.data(),
1426 mid_plane, br.l0.ic);
1427 depthwise_conv3x3_relu(mid_plane, br.l1.w.data(), br.l1.b.data(), l1_accum,
1428 16);
1429 conv1x1_relu(l1_accum, br.l2.w.data(), br.l2.b.data(), out, 16, 16);
1430#endif
1431 }
1432
1433 void branch_forward_with_scratch(int b, const float *in_planes, float *out,
1434 float *scratch0, float *scratch1) {
1435 const auto &br = shared_weights().branches[(size_t)b];
1436
1437 if (branchConvLayers == 1) {
1438 if (bd == 16)
1439 conv3x3_relu_bd16_dispatch(in_planes, br.l0.w.data(), br.l0.b.data(),
1440 out, br.l0.ic);
1441 else
1442 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), out, br.l0.ic,
1443 bd);
1444 return;
1445 }
1446
1447 if (branchConvLayers == 2) {
1448 if (bd == 16) {
1449 conv3x3_relu_bd16_dispatch(in_planes, br.l0.w.data(), br.l0.b.data(),
1450 scratch0, br.l0.ic);
1451 conv3x3_relu(scratch0, br.l1.w.data(), br.l1.b.data(), out, 16, 16);
1452 } else {
1453 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), scratch0,
1454 br.l0.ic, bd);
1455 conv3x3_relu(scratch0, br.l1.w.data(), br.l1.b.data(), out, bd, bd);
1456 }
1457 return;
1458 }
1459
1460 if (bd == 16) {
1461 branch_forward_bd16_fast(br, in_planes, out, scratch0, scratch1);
1462 return;
1463 }
1464 conv3x3_relu(in_planes, br.l0.w.data(), br.l0.b.data(), scratch0, br.l0.ic,
1465 bd);
1466 depthwise_conv3x3_relu(scratch0, br.l1.w.data(), br.l1.b.data(), scratch1,
1467 bd);
1468 conv1x1_relu(scratch1, br.l2.w.data(), br.l2.b.data(), out, bd, bd);
1469 }
1470
1471 void branch_forward(int b, const float *in_planes, float *out) {
1472 branch_forward_with_scratch(b, in_planes, out, scratchT0.data(),
1473 scratchT1.data());
1474 }
1475
1477 auto &hacc = hiddenAcc[(size_t)e];
1478 const auto &ex = shared_weights().experts[(size_t)e];
1479 const auto &flat = exReluCache[(size_t)e];
1480 const int flat_sz = ebo * 64;
1481
1482 for (int out = 0; out < eh; ++out)
1483 hacc[(size_t)out] = ex.bH[(size_t)out];
1484
1485 for (int i = 0; i < flat_sz; ++i) {
1486 const float fv = flat[(size_t)i];
1487 if (std::abs(fv) < 1e-7f)
1488 continue;
1489 const float *wt = &ex.wHT[(size_t)i * eh];
1490 for (int out = 0; out < eh; ++out)
1491 hacc[(size_t)out] += fv * wt[out];
1492 }
1493 }
1494
1496 auto &hacc = hiddenAcc[(size_t)e];
1497 auto &gap = exGapCache[(size_t)e];
1498 const auto &ex = shared_weights().experts[(size_t)e];
1499 const auto &flat = exReluCache[(size_t)e];
1500 constexpr float inv64 = 1.0f / 64.0f;
1501
1502 for (int bo = 0; bo < ebo; ++bo) {
1503 float s = 0.0f;
1504 const float *rbo = &flat[(size_t)bo * 64];
1506 for (int sq = 0; sq < 64; ++sq)
1507 s += rbo[sq];
1508 gap[(size_t)bo] = s * inv64;
1509 }
1510
1511 for (int out = 0; out < eh; ++out)
1512 hacc[(size_t)out] = ex.bH[(size_t)out];
1513
1514 for (int bo = 0; bo < ebo; ++bo) {
1515 const float gv = gap[(size_t)bo];
1516 if (std::abs(gv) < 1e-7f)
1517 continue;
1518 const float *wt = &ex.wHGT[(size_t)bo * eh];
1519 for (int out = 0; out < eh; ++out)
1520 hacc[(size_t)out] += gv * wt[out];
1521 }
1522 }
1523
1524 void rebuild_hidden_acc_from_pool2x2(int e, bool max_pool) {
1525 auto &hacc = hiddenAcc[(size_t)e];
1526 auto &pool = exPool16Cache[(size_t)e];
1527 const auto &ex = shared_weights().experts[(size_t)e];
1528 const auto &flat = exReluCache[(size_t)e];
1529
1530 for (int bo = 0; bo < ebo; ++bo) {
1531 const float *rbo = &flat[(size_t)bo * 64];
1532 for (int region = 0; region < kPool2x2Regions; ++region) {
1533 int r0 = 0, c0 = 0;
1534 pool2x2_region_base(region, r0, c0);
1535 const int sq0 = r0 * 8 + c0;
1536 const int sq1 = sq0 + 1;
1537 const int sq2 = sq0 + 8;
1538 const int sq3 = sq2 + 1;
1539
1540 float pv;
1541 if (max_pool) {
1542 pv = std::max(std::max(rbo[sq0], rbo[sq1]),
1543 std::max(rbo[sq2], rbo[sq3]));
1544 } else {
1545 pv = 0.25f * (rbo[sq0] + rbo[sq1] + rbo[sq2] + rbo[sq3]);
1546 }
1547 pool[(size_t)bo * kPool2x2Regions + region] = pv;
1548 }
1549 }
1550
1551 for (int out = 0; out < eh; ++out)
1552 hacc[(size_t)out] = ex.bH[(size_t)out];
1553
1554 const int pool_sz = ebo * kPool2x2Regions;
1555 for (int i = 0; i < pool_sz; ++i) {
1556 const float pv = pool[(size_t)i];
1557 if (std::abs(pv) < 1e-7f)
1558 continue;
1559 const float *wt = &ex.wH16T[(size_t)i * eh];
1560 for (int out = 0; out < eh; ++out)
1561 hacc[(size_t)out] += pv * wt[out];
1562 }
1563 }
1564
1565 inline float global_proj_at(int oc, const float *g) const {
1566 const auto &shared = shared_weights();
1567 float s = shared.globalB[(size_t)oc];
1568 const float *w = &shared.globalW[(size_t)oc * nGlobals];
1569 for (int i = 0; i < nGlobals; ++i)
1570 s += w[i] * g[i];
1571 return s;
1572 }
1573
1574 void top2_experts(const float *global, int &e0, int &e1, float &w0,
1575 float &w1) const {
1576 e0 = 0;
1577 e1 = 1;
1578 float s0 = std::numeric_limits<float>::lowest();
1579 float s1 = std::numeric_limits<float>::lowest();
1580 const auto &shared = shared_weights();
1581
1582 for (int e = 0; e < nExperts; ++e) {
1583 float s = shared.gateB[(size_t)e];
1584 const float *w = &shared.gateW[(size_t)e * nGlobals];
1585 if (routeSlowGlobals) {
1586 const int slow_n = std::min(nGlobals, 15);
1587 for (int i = 0; i < slow_n; ++i)
1588 s += w[i] * global[i];
1589 } else {
1590 for (int i = 0; i < nGlobals; ++i)
1591 s += w[i] * global[i];
1592 }
1593
1594 if (s > s0) {
1595 s1 = s0;
1596 e1 = e0;
1597 s0 = s;
1598 e0 = e;
1599 } else if (s > s1) {
1600 s1 = s;
1601 e1 = e;
1602 }
1603 }
1604
1605 const float m = std::max(s0, s1);
1606 const float p0 = std::exp(s0 - m);
1607 const float p1 = std::exp(s1 - m);
1608 const float z = p0 + p1;
1609 if (z > 1e-20f) {
1610 w0 = p0 / z;
1611 w1 = p1 / z;
1612 } else {
1613 w0 = 0.5f;
1614 w1 = 0.5f;
1615 }
1616 }
1617
1619 auto &pre = exPreAccum[(size_t)e];
1620 auto &rel = exReluCache[(size_t)e];
1621 const auto &ex = shared_weights().experts[(size_t)e];
1622 for (int bo = 0; bo < ebo; ++bo) {
1623 float *pbo = &pre[(size_t)bo * 64];
1624
1625 const float b = ex.bConv[(size_t)bo];
1626 for (int sq = 0; sq < 64; ++sq)
1627 pbo[sq] = b;
1628
1629 const float *w = &ex.wConv[(size_t)bo * nf];
1630 for (int oc = 0; oc < nf; ++oc) {
1631 const float wc = w[oc];
1632 const float *in_plane = &mixerReluCache[(size_t)oc * 64];
1633#pragma GCC ivdep
1634 for (int sq = 0; sq < 64; ++sq)
1635 pbo[sq] += in_plane[sq] * wc;
1636 }
1637
1638 float *rbo = &rel[(size_t)bo * 64];
1639 for (int sq = 0; sq < 64; ++sq)
1640 rbo[sq] = std::max(0.0f, pbo[sq]);
1641 }
1646 else
1649 exValid[(size_t)e] = 1;
1650 }
1651
1666 const int *active_experts, int active_count) {
1667 const auto &shared = shared_weights();
1668 const float *w_mixer_br = shared.mixerWBr.data();
1669 const float *w_mixer_bp = shared.mixerWBp.data();
1670 auto tBranch0 = Clock::now();
1671 // 1) Recompute and cache all branch outputs.
1672 parallel_for_indices(12, 2, [&](int b) {
1673 const float *pin = &inp.branches[b][0][0];
1674 float *bout = &branchCache[(size_t)b * bd * 64];
1675 branch_forward_with_scratch(b, pin, bout,
1676 scratchParallelBranch0[(size_t)b].data(),
1677 scratchParallelBranch1[(size_t)b].data());
1678 });
1679 auto tBranch1 = Clock::now();
1680 profile.fullBranchForwardUs += us(tBranch0, tBranch1);
1681
1682 // 2) Rebuild mixer linear accumulator from branches + bypass + bias.
1683 auto tMix0 = Clock::now();
1684 // Bias initialize.
1685 for (int oc = 0; oc < nf; ++oc) {
1686 float *macc = &mixerLinearAccum[(size_t)oc * 64];
1687 const float bias = shared.mixerB[(size_t)oc];
1688 for (int sq = 0; sq < 64; ++sq)
1689 macc[sq] = bias;
1690 }
1691
1692 // Branch+bypass accumulation using oc tiles to improve cache locality.
1693 const int ocTiles = (nf + kMixerOcTile - 1) / kMixerOcTile;
1694 parallel_for_indices(ocTiles, 2, [&](int tile) {
1695 const int oc0 = tile * kMixerOcTile;
1696 const int ocl = std::min(kMixerOcTile, nf - oc0);
1697 float *macc[kMixerOcTile];
1698 for (int t = 0; t < ocl; ++t)
1699 macc[t] = &mixerLinearAccum[(size_t)(oc0 + t) * 64];
1700
1701 for (int b = 0; b < 12; ++b) {
1702 const float *bo = &branchCache[(size_t)b * bd * 64];
1703 for (int c = 0; c < bd; ++c) {
1704 const float *in_plane = &bo[(size_t)c * 64];
1705 for (int t = 0; t < ocl; ++t) {
1706 const float wc = w_mixer_br[((size_t)b * nf + (oc0 + t)) * bd + c];
1707 simd_add_scaled(macc[t], in_plane, wc, 64);
1708 }
1709 }
1710 }
1711
1712 for (int bp = 0; bp < nBypass; ++bp) {
1713 const float *in_plane = &inp.bypass[bp][0];
1714 for (int t = 0; t < ocl; ++t) {
1715 const float wc = w_mixer_bp[(size_t)bp * nf + (oc0 + t)];
1716 simd_add_scaled(macc[t], in_plane, wc, 64);
1717 }
1718 }
1719 });
1720 auto tMix1 = Clock::now();
1721 profile.fullMixerAccumUs += us(tMix0, tMix1);
1722
1723 // 3) Build post-ReLU mixer cache.
1724 auto tRelu0 = Clock::now();
1725 for (int oc = 0; oc < nf; ++oc) {
1726 const float g = global_proj_at(oc, inp.global);
1727 for (int sq = 0; sq < 64; ++sq) {
1728 const float v = mixerLinearAccum[(size_t)oc * 64 + sq] + g;
1729 mixerReluCache[(size_t)oc * 64 + sq] = std::max(0.0f, v);
1730 }
1731 }
1732 auto tRelu1 = Clock::now();
1733 profile.fullMixerReluUs += us(tRelu0, tRelu1);
1734
1735 // 4) Build only active expert bottleneck caches (Top-2 routing).
1736 auto tEx0 = Clock::now();
1737 std::fill(exValid.begin(), exValid.end(), 0);
1738 if (active_count == 2 && active_experts[0] != active_experts[1]) {
1739 parallel_for_indices(2, 2, [&](int i) {
1740 rebuild_expert_cache_from_mixer(active_experts[i]);
1741 });
1742 } else {
1743 for (int i = 0; i < active_count; ++i)
1744 rebuild_expert_cache_from_mixer(active_experts[i]);
1745 }
1746 auto tEx1 = Clock::now();
1747 profile.fullExpertCacheUs += us(tEx0, tEx1);
1748
1749 std::memcpy(oldGlobalV.data(), inp.global, sizeof(float) * nGlobals);
1750
1751 initialized = true;
1752 }
1753
1773 const FactorizedInput &prev,
1774 const int *dirty_branches, int dirty_count,
1775 const int *active_experts, int active_count) {
1776 const auto &shared = shared_weights();
1777 const float *w_mixer_br = shared.mixerWBr.data();
1778 const float *w_mixer_bp = shared.mixerWBp.data();
1779 if (!initialized) {
1780 full_rebuild_accumulators(cur, active_experts, active_count);
1781 return;
1782 }
1783
1784 uint64_t dirty_sq_mask = 0ULL;
1785 std::fill(scratchDeltaRelu.begin(), scratchDeltaRelu.end(), 0.0f);
1786 std::fill(scratchFlatDelta.begin(), scratchFlatDelta.end(), 0.0f);
1787
1788 // A) Branch deltas -> mixer linear accumulator.
1789 auto tA0 = Clock::now();
1790 parallel_for_indices(dirty_count, minParallelDirtyHeads, [&](int i) {
1791 const int b = dirty_branches[i];
1792 const float *pin = &cur.branches[b][0][0];
1793 float *new_cache = &scratchDirtyBranches[(size_t)i * (size_t)bd * 64];
1794 branch_forward_with_scratch(b, pin, new_cache,
1795 scratchParallelBranch0[(size_t)i].data(),
1796 scratchParallelBranch1[(size_t)i].data());
1797 });
1798
1799 for (int i = 0; i < dirty_count; ++i) {
1800 const int b = dirty_branches[i];
1801 const float *new_branch =
1802 &scratchDirtyBranches[(size_t)i * (size_t)bd * 64];
1803
1804 float *old_cache = &branchCache[(size_t)b * bd * 64];
1805 for (int c = 0; c < bd; ++c) {
1806 const float *old_plane = &old_cache[(size_t)c * 64];
1807 const float *new_plane = &new_branch[(size_t)c * 64];
1808 float *delta_plane = &scratchBranchDelta[(size_t)c * 64];
1809
1810#pragma GCC ivdep
1811 for (int sq = 0; sq < 64; ++sq)
1812 delta_plane[sq] = new_plane[sq] - old_plane[sq];
1813
1814 for (int sq = 0; sq < 64; ++sq) {
1815 if (std::abs(delta_plane[sq]) > 1e-9f)
1816 dirty_sq_mask |= (1ULL << sq);
1817 }
1818 }
1819
1820 // Accumulate branch deltas across all 64 squares with SIMD-friendly
1821 // loops.
1822 constexpr int kTileOC = 8;
1823 const int ocTiles = (nf + kTileOC - 1) / kTileOC;
1824 for (int tile = 0; tile < ocTiles; ++tile) {
1825 const int oc0 = tile * kTileOC;
1826 const int ocl = std::min(kTileOC, nf - oc0);
1827 for (int c = 0; c < bd; ++c) {
1828 const float *ds = &scratchBranchDelta[(size_t)c * 64];
1829 for (int t = 0; t < ocl; ++t) {
1830 const float w = w_mixer_br[((size_t)b * nf + (oc0 + t)) * bd + c];
1831 float *macc = &mixerLinearAccum[(size_t)(oc0 + t) * 64];
1832 simd_add_scaled(macc, ds, w, 64);
1833 }
1834 }
1835 }
1836
1837 std::memcpy(old_cache, new_branch, (size_t)bd * 64 * sizeof(float));
1838 }
1839 auto tA1 = Clock::now();
1840 profile.incBranchDeltaUs += us(tA0, tA1);
1841
1842 // B) Bypass deltas -> mixer linear accumulator.
1843 auto tB0 = Clock::now();
1844 std::array<uint8_t, kMaxBypass> bypass_dirty{};
1845 for (int bp = 0; bp < nBypass; ++bp) {
1846 bool has_delta = false;
1847 for (int sq = 0; sq < 64; ++sq) {
1848 const float delta = cur.bypass[bp][sq] - prev.bypass[bp][sq];
1849 scratchBypassDelta[(size_t)bp * 64 + sq] = delta;
1850 if (std::abs(delta) > 1e-9f) {
1851 has_delta = true;
1852 dirty_sq_mask |= (1ULL << sq);
1853 }
1854 }
1855 bypass_dirty[(size_t)bp] = has_delta ? 1u : 0u;
1856 }
1857
1858 // Accumulate only dirty bypass planes across all 64 squares.
1859 {
1860 constexpr int kTileOC = 8;
1861 const int ocTiles = (nf + kTileOC - 1) / kTileOC;
1862 for (int tile = 0; tile < ocTiles; ++tile) {
1863 const int oc0 = tile * kTileOC;
1864 const int ocl = std::min(kTileOC, nf - oc0);
1865 for (int bp = 0; bp < nBypass; ++bp) {
1866 if (!bypass_dirty[(size_t)bp])
1867 continue;
1868 const float *ds = &scratchBypassDelta[(size_t)bp * 64];
1869 for (int t = 0; t < ocl; ++t) {
1870 const float w = w_mixer_bp[(size_t)bp * nf + (oc0 + t)];
1871 float *macc = &mixerLinearAccum[(size_t)(oc0 + t) * 64];
1872 simd_add_scaled(macc, ds, w, 64);
1873 }
1874 }
1875 }
1876 }
1877 auto tB1 = Clock::now();
1878 profile.incBypassDeltaUs += us(tB0, tB1);
1879
1880 // C) If globals changed, all squares are dirty (global is broadcast over
1881 // board).
1882 auto tCD0 = Clock::now();
1883 bool globals_changed = false;
1884 for (int g = 0; g < nGlobals; ++g) {
1885 if (std::abs(cur.global[g] - oldGlobalV[(size_t)g]) > 1e-6f) {
1886 globals_changed = true;
1887 break;
1888 }
1889 }
1890 auto tCD1 = Clock::now();
1891 profile.incGlobalReluUs += us(tCD0, tCD1);
1892 std::memcpy(oldGlobalV.data(), cur.global, sizeof(float) * nGlobals);
1893 if (globals_changed)
1894 dirty_sq_mask = ~0ULL;
1895
1896 for (int oc = 0; oc < nf; ++oc)
1897 scratchGproj[(size_t)oc] = global_proj_at(oc, cur.global);
1898
1899 // D) Compute post-ReLU deltas using ctz bit-scan over dirty squares.
1900 bool any_mixer_delta = false;
1901 uint64_t delta_sq_mask = 0ULL;
1902 for (int oc = 0; oc < nf; ++oc) {
1903 const float gb = scratchGproj[(size_t)oc];
1904 uint64_t mask = dirty_sq_mask;
1905 while (mask) {
1906 const int sq = __builtin_ctzll(mask);
1907 mask &= (mask - 1);
1908
1909 const size_t idx = (size_t)oc * 64 + sq;
1910 const float oldv = mixerReluCache[idx];
1911 const float newv = std::max(0.0f, mixerLinearAccum[idx] + gb);
1912 const float d = newv - oldv;
1913 scratchDeltaRelu[idx] = d;
1914 mixerReluCache[idx] = newv;
1915 if (d != 0.0f) {
1916 any_mixer_delta = true;
1917 delta_sq_mask |= (1ULL << sq);
1918 }
1919 }
1920 }
1921
1922 int dirty_sq_idx[64];
1923 int dirty_sq_count = 0;
1924 int dirty_region_idx[kPool2x2Regions];
1925 int dirty_region_count = 0;
1926 uint32_t region_mask = 0;
1927 uint64_t work = delta_sq_mask;
1928 while (work) {
1929 const int sq = __builtin_ctzll(work);
1930 work &= (work - 1);
1931 dirty_sq_idx[dirty_sq_count++] = sq;
1932 region_mask |= (1u << pool2x2_region_from_sq(sq));
1933 }
1934 const bool dense_dirty = dirty_sq_count >= denseDirtySqThreshold;
1935
1936 while (region_mask) {
1937 const int region = __builtin_ctz(region_mask);
1938 region_mask &= (region_mask - 1);
1939 dirty_region_idx[dirty_region_count++] = region;
1940 }
1941
1942 // D) Expert bottleneck deltas: update only active Top-2 expert caches.
1943 if (!any_mixer_delta) {
1944 std::array<double, kMaxExperts> reb_us{};
1945 parallel_for_indices(active_count, minParallelActiveExperts, [&](int i) {
1946 const int e = active_experts[i];
1947 if (!exValid[(size_t)e]) {
1948 const auto tReb0 = Clock::now();
1950 const auto tReb1 = Clock::now();
1951 reb_us[(size_t)i] = us(tReb0, tReb1);
1952 }
1953 });
1954 for (int i = 0; i < active_count; ++i)
1955 profile.incExpertCacheRebuildUs += reb_us[(size_t)i];
1956 return;
1957 }
1958
1959 std::array<uint8_t, kMaxExperts> new_valid{};
1960 std::array<int, kMaxExperts> processed_e{};
1961 std::array<double, kMaxExperts> reb_us{};
1962 std::array<double, kMaxExperts> bott_us{};
1963 std::array<double, kMaxExperts> hidden_us{};
1964
1965 parallel_for_indices(active_count, minParallelActiveExperts, [&](int i) {
1966 const int e = active_experts[i];
1967 processed_e[(size_t)i] = e;
1968
1969 if (!exValid[(size_t)e]) {
1970 const auto tReb0 = Clock::now();
1972 const auto tReb1 = Clock::now();
1973 reb_us[(size_t)i] = us(tReb0, tReb1);
1974 return;
1975 }
1976
1977 auto &pre = exPreAccum[(size_t)e];
1978 auto &rel = exReluCache[(size_t)e];
1979 auto &hacc = hiddenAcc[(size_t)e];
1980 auto &gap = exGapCache[(size_t)e];
1981 auto &pool16 = exPool16Cache[(size_t)e];
1982 const auto &ex = shared.experts[(size_t)e];
1983 const int flat_sz = ebo * 64;
1984 float *localFlatDelta = scratchParallelExpertDelta[(size_t)e].data();
1985 std::array<float, kMaxExpertBottleneck> localGapDelta{};
1986 std::array<float, (size_t)kMaxExpertBottleneck * kPool2x2Regions>
1987 localPoolDelta{};
1989 std::fill(localFlatDelta, localFlatDelta + flat_sz, 0.0f);
1990
1991 const auto tBott0 = Clock::now();
1993 constexpr float inv64 = 1.0f / 64.0f;
1994 for (int bo = 0; bo < ebo; ++bo) {
1995 float *pbo = &pre[(size_t)bo * 64];
1996 if (dense_dirty) {
1997 const float *w = &ex.wConv[(size_t)bo * nf];
1998 for (int ic = 0; ic < nf; ++ic) {
1999 const float wc = w[ic];
2000 const float *dr = &scratchDeltaRelu[(size_t)ic * 64];
2002 for (int sq = 0; sq < 64; ++sq)
2003 pbo[sq] += dr[sq] * wc;
2004 }
2005 } else {
2006 const float *w = &ex.wConv[(size_t)bo * nf];
2007 for (int k = 0; k < dirty_sq_count; ++k) {
2008 const int sq = dirty_sq_idx[k];
2009 float delta = 0.0f;
2010 for (int ic = 0; ic < nf; ++ic)
2011 delta += scratchDeltaRelu[(size_t)ic * 64 + sq] * w[ic];
2012 pbo[sq] += delta;
2013 }
2014 }
2015
2016 float *rbo = &rel[(size_t)bo * 64];
2017 const float *pbo_ro = &pre[(size_t)bo * 64];
2018 float ch_sum_delta = 0.0f;
2019 if (dense_dirty) {
2020 for (int sq = 0; sq < 64; ++sq) {
2021 const float old_flat = rbo[sq];
2022 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2023 rbo[sq] = new_flat;
2024 ch_sum_delta += (new_flat - old_flat);
2025 }
2026 } else {
2027 for (int k = 0; k < dirty_sq_count; ++k) {
2028 const int sq = dirty_sq_idx[k];
2029 const float old_flat = rbo[sq];
2030 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2031 rbo[sq] = new_flat;
2032 ch_sum_delta += (new_flat - old_flat);
2033 }
2034 }
2035 const float gd = ch_sum_delta * inv64;
2036 localGapDelta[(size_t)bo] = gd;
2037 gap[(size_t)bo] += gd;
2038 }
2039 } else if (expertPoolMode == ExpertPoolMode::Flat) {
2040 for (int bo = 0; bo < ebo; ++bo) {
2041 float *pbo = &pre[(size_t)bo * 64];
2042 if (dense_dirty) {
2043 const float *w = &ex.wConv[(size_t)bo * nf];
2044 for (int ic = 0; ic < nf; ++ic) {
2045 const float wc = w[ic];
2046 const float *dr = &scratchDeltaRelu[(size_t)ic * 64];
2048 for (int sq = 0; sq < 64; ++sq)
2049 pbo[sq] += dr[sq] * wc;
2050 }
2051 } else {
2052 const float *w = &ex.wConv[(size_t)bo * nf];
2053 for (int k = 0; k < dirty_sq_count; ++k) {
2054 const int sq = dirty_sq_idx[k];
2055 float delta = 0.0f;
2056 for (int ic = 0; ic < nf; ++ic)
2057 delta += scratchDeltaRelu[(size_t)ic * 64 + sq] * w[ic];
2058 pbo[sq] += delta;
2059 }
2060 }
2061
2062 float *rbo = &rel[(size_t)bo * 64];
2063 const float *pbo_ro = &pre[(size_t)bo * 64];
2064 if (dense_dirty) {
2065 for (int sq = 0; sq < 64; ++sq) {
2066 const float old_flat = rbo[sq];
2067 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2068 rbo[sq] = new_flat;
2069 localFlatDelta[(size_t)bo * 64 + sq] = new_flat - old_flat;
2070 }
2071 } else {
2072 for (int k = 0; k < dirty_sq_count; ++k) {
2073 const int sq = dirty_sq_idx[k];
2074 const float old_flat = rbo[sq];
2075 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2076 rbo[sq] = new_flat;
2077 localFlatDelta[(size_t)bo * 64 + sq] = new_flat - old_flat;
2078 }
2079 }
2080 }
2082 constexpr float inv4 = 0.25f;
2083 for (int bo = 0; bo < ebo; ++bo) {
2084 float *pbo = &pre[(size_t)bo * 64];
2085 if (dense_dirty) {
2086 const float *w = &ex.wConv[(size_t)bo * nf];
2087 for (int ic = 0; ic < nf; ++ic) {
2088 const float wc = w[ic];
2089 const float *dr = &scratchDeltaRelu[(size_t)ic * 64];
2091 for (int sq = 0; sq < 64; ++sq)
2092 pbo[sq] += dr[sq] * wc;
2093 }
2094 } else {
2095 const float *w = &ex.wConv[(size_t)bo * nf];
2096 for (int k = 0; k < dirty_sq_count; ++k) {
2097 const int sq = dirty_sq_idx[k];
2098 float delta = 0.0f;
2099 for (int ic = 0; ic < nf; ++ic)
2100 delta += scratchDeltaRelu[(size_t)ic * 64 + sq] * w[ic];
2101 pbo[sq] += delta;
2102 }
2103 }
2104
2105 float *rbo = &rel[(size_t)bo * 64];
2106 const float *pbo_ro = &pre[(size_t)bo * 64];
2107 if (dense_dirty) {
2108 for (int sq = 0; sq < 64; ++sq) {
2109 const float old_flat = rbo[sq];
2110 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2111 rbo[sq] = new_flat;
2112 const int region = pool2x2_region_from_sq(sq);
2113 localPoolDelta[(size_t)bo * kPool2x2Regions + region] +=
2114 (new_flat - old_flat) * inv4;
2115 }
2116 } else {
2117 for (int k = 0; k < dirty_sq_count; ++k) {
2118 const int sq = dirty_sq_idx[k];
2119 const float old_flat = rbo[sq];
2120 const float new_flat = std::max(0.0f, pbo_ro[sq]);
2121 rbo[sq] = new_flat;
2122 const int region = pool2x2_region_from_sq(sq);
2123 localPoolDelta[(size_t)bo * kPool2x2Regions + region] +=
2124 (new_flat - old_flat) * inv4;
2125 }
2126 }
2127
2128 const int regions_to_loop =
2129 dense_dirty ? kPool2x2Regions : dirty_region_count;
2130 for (int rk = 0; rk < regions_to_loop; ++rk) {
2131 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2132 const size_t pidx = (size_t)bo * kPool2x2Regions + region;
2133 pool16[pidx] += localPoolDelta[pidx];
2134 }
2135 }
2136 } else {
2137 for (int bo = 0; bo < ebo; ++bo) {
2138 float *pbo = &pre[(size_t)bo * 64];
2139 float *rbo = &rel[(size_t)bo * 64];
2140 if (dense_dirty) {
2141 const float *w = &ex.wConv[(size_t)bo * nf];
2142 for (int ic = 0; ic < nf; ++ic) {
2143 const float wc = w[ic];
2144 const float *dr = &scratchDeltaRelu[(size_t)ic * 64];
2146 for (int sq = 0; sq < 64; ++sq)
2147 pbo[sq] += dr[sq] * wc;
2148 }
2149 } else {
2150 const float *w = &ex.wConv[(size_t)bo * nf];
2151 for (int k = 0; k < dirty_sq_count; ++k) {
2152 const int sq = dirty_sq_idx[k];
2153 float delta = 0.0f;
2154 for (int ic = 0; ic < nf; ++ic)
2155 delta += scratchDeltaRelu[(size_t)ic * 64 + sq] * w[ic];
2156 pbo[sq] += delta;
2157 }
2158 }
2159
2160 if (dense_dirty) {
2161 for (int sq = 0; sq < 64; ++sq)
2162 rbo[sq] = std::max(0.0f, pbo[sq]);
2163 } else {
2164 for (int k = 0; k < dirty_sq_count; ++k) {
2165 const int sq = dirty_sq_idx[k];
2166 rbo[sq] = std::max(0.0f, pbo[sq]);
2167 }
2168 }
2169
2170 const int regions_to_loop =
2171 dense_dirty ? kPool2x2Regions : dirty_region_count;
2172 for (int rk = 0; rk < regions_to_loop; ++rk) {
2173 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2174 int r0 = 0, c0 = 0;
2175 pool2x2_region_base(region, r0, c0);
2176 const int sq0 = r0 * 8 + c0;
2177 const int sq1 = sq0 + 1;
2178 const int sq2 = sq0 + 8;
2179 const int sq3 = sq2 + 1;
2180 const float new_pool = std::max(std::max(rbo[sq0], rbo[sq1]),
2181 std::max(rbo[sq2], rbo[sq3]));
2182 const size_t pidx = (size_t)bo * kPool2x2Regions + region;
2183 const float old_pool = pool16[pidx];
2184 pool16[pidx] = new_pool;
2185 localPoolDelta[pidx] = new_pool - old_pool;
2186 }
2187 }
2188 }
2189 const auto tBott1 = Clock::now();
2190 bott_us[(size_t)i] = us(tBott0, tBott1);
2191
2192 const auto tHidden0 = Clock::now();
2194 for (int bo = 0; bo < ebo; ++bo) {
2195 const float delta = localGapDelta[(size_t)bo];
2196 if (std::abs(delta) < 1e-9f)
2197 continue;
2198 const float *wt = &ex.wHGT[(size_t)bo * eh];
2199 for (int out = 0; out < eh; ++out)
2200 hacc[(size_t)out] += delta * wt[out];
2201 }
2202 } else if (expertPoolMode == ExpertPoolMode::Flat) {
2203 if (dense_dirty) {
2204 for (int fi = 0; fi < flat_sz; ++fi) {
2205 const float delta = localFlatDelta[(size_t)fi];
2206 if (std::abs(delta) < 1e-7f)
2207 continue;
2208 const float *wt = &ex.wHT[(size_t)fi * eh];
2209#pragma GCC ivdep
2210 for (int out = 0; out < eh; ++out)
2211 hacc[(size_t)out] += delta * wt[out];
2212 }
2213 } else {
2214 for (int bo = 0; bo < ebo; ++bo) {
2215 for (int k = 0; k < dirty_sq_count; ++k) {
2216 const int sq = dirty_sq_idx[k];
2217 const int fi = bo * 64 + sq;
2218 const float delta = localFlatDelta[(size_t)fi];
2219 if (std::abs(delta) < 1e-7f)
2220 continue;
2221 const float *wt = &ex.wHT[(size_t)fi * eh];
2222 for (int out = 0; out < eh; ++out)
2223 hacc[(size_t)out] += delta * wt[out];
2224 }
2225 }
2226 }
2228 const int regions_to_loop =
2229 dense_dirty ? kPool2x2Regions : dirty_region_count;
2230 for (int bo = 0; bo < ebo; ++bo) {
2231 for (int rk = 0; rk < regions_to_loop; ++rk) {
2232 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2233 const int pidx = bo * kPool2x2Regions + region;
2234 const float delta = localPoolDelta[(size_t)pidx];
2235 if (std::abs(delta) < 1e-9f)
2236 continue;
2237 const float *wt = &ex.wH16T[(size_t)pidx * eh];
2238#pragma GCC ivdep
2239 for (int out = 0; out < eh; ++out)
2240 hacc[(size_t)out] += delta * wt[out];
2241 }
2242 }
2243 } else {
2244 const int regions_to_loop =
2245 dense_dirty ? kPool2x2Regions : dirty_region_count;
2246 for (int bo = 0; bo < ebo; ++bo) {
2247 for (int rk = 0; rk < regions_to_loop; ++rk) {
2248 const int region = dense_dirty ? rk : dirty_region_idx[rk];
2249 const int pidx = bo * kPool2x2Regions + region;
2250 const float delta = localPoolDelta[(size_t)pidx];
2251 if (std::abs(delta) < 1e-9f)
2252 continue;
2253 const float *wt = &ex.wH16T[(size_t)pidx * eh];
2254#pragma GCC ivdep
2255 for (int out = 0; out < eh; ++out)
2256 hacc[(size_t)out] += delta * wt[out];
2257 }
2258 }
2259 }
2260 const auto tHidden1 = Clock::now();
2261 hidden_us[(size_t)i] = us(tHidden0, tHidden1);
2262 });
2263
2264 for (int i = 0; i < active_count; ++i) {
2265 profile.incExpertCacheRebuildUs += reb_us[(size_t)i];
2266 profile.incExpertBottleneckUs += bott_us[(size_t)i];
2267 profile.incHiddenDeltaUs += hidden_us[(size_t)i];
2268 new_valid[(size_t)processed_e[(size_t)i]] = 1;
2269 }
2270 exValid = new_valid;
2271 }
2272
2284 void run_active_expert(int e, float out_wdl[3]) {
2285 out_wdl[0] = out_wdl[1] = out_wdl[2] = 0.0f;
2286
2287 const auto &ex = shared_weights().experts[(size_t)e];
2288 const auto &hacc = hiddenAcc[(size_t)e];
2289
2290 for (int h = 0; h < eh; ++h)
2291 scratchHidden[(size_t)h] = std::max(0.0f, hacc[(size_t)h]);
2292
2293 float wdl_logits[3] = {ex.bWdl[0], ex.bWdl[1], ex.bWdl[2]};
2294 for (int o = 0; o < 3; ++o) {
2295 const float *w = &ex.wWdl[(size_t)o * eh];
2296 for (int h = 0; h < eh; ++h)
2297 wdl_logits[o] += scratchHidden[(size_t)h] * w[h];
2298 }
2299
2300 const float mx =
2301 std::max(wdl_logits[0], std::max(wdl_logits[1], wdl_logits[2]));
2302 float s = 0.0f;
2303 float wdl_prob[3];
2304 for (int o = 0; o < 3; ++o) {
2305 wdl_prob[o] = std::exp(wdl_logits[o] - mx);
2306 s += wdl_prob[o];
2307 }
2308 if (s > 1e-20f) {
2309 const float inv = 1.0f / s;
2310 for (float &p : wdl_prob)
2311 p *= inv;
2312 }
2313
2314 for (int o = 0; o < 3; ++o)
2315 out_wdl[o] = wdl_prob[o];
2316 }
2317
2331 void run_top2_experts(int e0, int e1, float w0, float w1, float out_wdl[3]) {
2332 float wdl0[3] = {0.0f, 0.0f, 0.0f};
2333 float wdl1[3] = {0.0f, 0.0f, 0.0f};
2334 run_active_expert(e0, wdl0);
2335 run_active_expert(e1, wdl1);
2336 for (int o = 0; o < 3; ++o)
2337 out_wdl[o] = w0 * wdl0[o] + w1 * wdl1[o];
2338 }
2339};
Feature extraction logic for generating packed factorized features.
static const char * expert_pool_mode_name(ExpertPoolMode mode)
Definition MoECacheModel.hpp:835
static constexpr int kDefaultMixerOut
Definition MoECacheModel.hpp:87
static void pool2x2_region_base(int region, int &r0, int &c0)
Definition MoECacheModel.hpp:855
static constexpr int kMaxGlobals
Definition MoECacheModel.hpp:111
static bool is_slider_dirty(const Board &board, const Board &old_board, Color side, PieceType pt)
Checks if a slider piece type (Bishop, Rook, Queen) has become "dirty".
Definition MoECacheModel.hpp:191
static constexpr int NET_EXPERT_HIDDEN
Definition MoECacheModel.hpp:100
static int pool2x2_region_from_sq(int sq)
Definition MoECacheModel.hpp:849
static constexpr int kDefaultExpertHidden
Definition MoECacheModel.hpp:92
static double us(Clock::time_point a, Clock::time_point b)
Definition MoECacheModel.hpp:132
static constexpr int kMaxExpertHidden
Definition MoECacheModel.hpp:114
static void conv1x1_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
Computes a 1x1 convolution (pointwise convolution) across the spatial grid, followed by ReLU.
Definition MoECacheModel.hpp:685
static void simd_add_scaled(float *HOT_RESTRICT dst, const float *HOT_RESTRICT src, float scale, int n)
SIMD-optimized fused multiply-add (FMA) for adding a scaled vector to a destination vector....
Definition MoECacheModel.hpp:345
static constexpr int kMaxExperts
Definition MoECacheModel.hpp:112
ExpertPoolMode
Defines how spatial features are pooled before entering the fully-connected expert networks.
Definition MoECacheModel.hpp:826
@ Pool2x2Max
Definition MoECacheModel.hpp:830
@ Gap
Definition MoECacheModel.hpp:828
@ Flat
Definition MoECacheModel.hpp:827
@ Pool2x2Avg
Definition MoECacheModel.hpp:829
static bool is_slider_branch_rich(int branch_idx)
Determines if a specific branch corresponds to a slider piece type (Bishop, Rook, Queen).
Definition MoECacheModel.hpp:146
static void conv3x3_relu_bd16_dispatch(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic)
Definition MoECacheModel.hpp:501
static constexpr int kInputBypassPlanes
Definition MoECacheModel.hpp:103
static constexpr int NET_GLOBALS
Definition MoECacheModel.hpp:97
static constexpr int kMaxBypass
Definition MoECacheModel.hpp:110
static T * assume_aligned_32(T *ptr)
Hints the compiler that a pointer is aligned to a 32-byte boundary (for AVX operations).
Definition MoECacheModel.hpp:357
static constexpr int kMaxMixerOut
Definition MoECacheModel.hpp:109
static constexpr int kDefaultGlobals
Definition MoECacheModel.hpp:89
static void conv3x3_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int ic, int oc)
Computes a full 3x3 convolution from ic input channels to oc output channels, followed by ReLU.
Definition MoECacheModel.hpp:485
static constexpr int kDefaultBranchDim
Definition MoECacheModel.hpp:86
#define FORCE_VECTORIZE
Definition MoECacheModel.hpp:123
static void conv3x3_single_out_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, float b, float *HOT_RESTRICT out, int ic)
Computes a 3x3 convolution across multiple input channels (ic) to produce a single output channel,...
Definition MoECacheModel.hpp:457
static std::array< uint8_t, 12 > build_dirty_mask(const Board &old_board, const Board &new_board, const Move &mv)
Analyzes a chess move to determine exactly which of the 12 spatial branches need recomputing.
Definition MoECacheModel.hpp:257
static constexpr int kPool2x2Regions
Definition MoECacheModel.hpp:833
static void depthwise_conv3x3_relu(const float *HOT_RESTRICT in, const float *HOT_RESTRICT w, const float *HOT_RESTRICT b, float *HOT_RESTRICT out, int channels)
Computes a depthwise 3x3 convolution followed by a ReLU activation.
Definition MoECacheModel.hpp:649
static constexpr int NET_MIXER_OUT
Definition MoECacheModel.hpp:95
static void conv3x3_accumulate_plane(float *HOT_RESTRICT out_plane, const float *HOT_RESTRICT in_plane, const float *HOT_RESTRICT wk)
Computes a full 3x3 2D convolution for a single input channel and accumulates it into out_plane.
Definition MoECacheModel.hpp:385
static constexpr int kMaxBranchDim
Definition MoECacheModel.hpp:108
static constexpr int kInputGlobals
Definition MoECacheModel.hpp:104
static constexpr int kDefaultExpertBottleneck
Definition MoECacheModel.hpp:91
#define HOT_RESTRICT
Definition MoECacheModel.hpp:129
static constexpr int kMaxExpertBottleneck
Definition MoECacheModel.hpp:113
static constexpr int kMixerOcTile
Definition MoECacheModel.hpp:115
std::chrono::high_resolution_clock Clock
Definition MoECacheModel.hpp:84
static constexpr int kDefaultBypass
Definition MoECacheModel.hpp:88
static constexpr int NET_BRANCH_DIM
Definition MoECacheModel.hpp:94
static constexpr int NET_BYPASS
Definition MoECacheModel.hpp:96
static int active_channels_for_branch(int branch_idx)
Returns the number of active input feature channels for a given branch.
Definition MoECacheModel.hpp:167
static void mark_head(std::array< uint8_t, 12 > &m, Color s, PieceType pt)
Marks a specific piece type for a specific color as "dirty" in the dirty mask.
Definition MoECacheModel.hpp:223
static constexpr int NET_EXPERTS
Definition MoECacheModel.hpp:98
static constexpr int kDefaultExperts
Definition MoECacheModel.hpp:90
static constexpr int NET_EXPERT_BOTTLENECK
Definition MoECacheModel.hpp:99
static bool branch_planes_changed(const FactorizedInput &a, const FactorizedInput &b, int branch_idx)
Performs a fast byte-level comparison to see if two spatial branch inputs are identical.
Definition MoECacheModel.hpp:328
chess::Move Move
Alias for chess::Move.
Definition Types.h:15
chess::PieceType PieceType
Alias for chess::PieceType.
Definition Types.h:20
chess::Color Color
Alias for chess::Color.
Definition Types.h:17
chess::Piece Piece
Alias for chess::Piece.
Definition Types.h:19
chess::Bitboard Bitboard
Alias for chess::Bitboard.
Definition Types.h:21
chess::Square Square
Alias for chess::Square.
Definition Types.h:18
chess::Board Board
Alias for chess::Board.
Definition Types.h:14
Definition MoECacheModel.hpp:860
int nGlobals
Definition MoECacheModel.hpp:875
unsigned seed
Definition MoECacheModel.hpp:863
int branchConvLayers
Definition MoECacheModel.hpp:871
int minParallelActiveExperts
Definition MoECacheModel.hpp:866
int minParallelDirtyHeads
Definition MoECacheModel.hpp:865
int nThreads
Definition MoECacheModel.hpp:864
int nBypass
Definition MoECacheModel.hpp:874
int expertHidden
Definition MoECacheModel.hpp:878
int nExperts
Definition MoECacheModel.hpp:876
int nPlies
Definition MoECacheModel.hpp:862
int denseDirtySqThreshold
Definition MoECacheModel.hpp:867
ExpertPoolMode expertPoolMode
Definition MoECacheModel.hpp:868
int expertBottleneck
Definition MoECacheModel.hpp:877
int mixerOut
Definition MoECacheModel.hpp:873
int branchDim
Definition MoECacheModel.hpp:872
bool routeSlowGlobals
Definition MoECacheModel.hpp:869
int nGames
Definition MoECacheModel.hpp:861
Definition MoECacheModel.hpp:881
double incRouteUs
Definition MoECacheModel.hpp:899
double incBypassDeltaUs
Definition MoECacheModel.hpp:911
double incBranchDeltaUs
Definition MoECacheModel.hpp:910
double incExpertCacheRebuildUs
Definition MoECacheModel.hpp:915
double incExpertBottleneckUs
Definition MoECacheModel.hpp:913
double fullUs
Definition MoECacheModel.hpp:883
double incUs
Definition MoECacheModel.hpp:885
long long incPlies
Definition MoECacheModel.hpp:886
long long incEvents
Definition MoECacheModel.hpp:888
double prepFeatureUs
Definition MoECacheModel.hpp:892
double fullUpdateUs
Definition MoECacheModel.hpp:897
double incExpertUs
Definition MoECacheModel.hpp:901
long long fullPlies
Definition MoECacheModel.hpp:884
double incGlobalReluUs
Definition MoECacheModel.hpp:912
double incHiddenDeltaUs
Definition MoECacheModel.hpp:914
double avgDirtyHeads() const
Definition MoECacheModel.hpp:921
std::string name
Definition MoECacheModel.hpp:882
double fullExpertCacheUs
Definition MoECacheModel.hpp:907
double fullNps() const
Definition MoECacheModel.hpp:917
double fullMixerAccumUs
Definition MoECacheModel.hpp:905
double fullRouteUs
Definition MoECacheModel.hpp:896
long long incDirtyHeads
Definition MoECacheModel.hpp:887
double prepPlayUs
Definition MoECacheModel.hpp:891
double fullBranchForwardUs
Definition MoECacheModel.hpp:904
double incUpdateUs
Definition MoECacheModel.hpp:900
double fullExpertUs
Definition MoECacheModel.hpp:898
double fullMixerReluUs
Definition MoECacheModel.hpp:906
double incNps() const
Definition MoECacheModel.hpp:920
double prepDirtyUs
Definition MoECacheModel.hpp:893
Definition MoECacheModel.hpp:926
std::array< float, NET_BRANCH_DIM > b
Definition MoECacheModel.hpp:930
int oc
Definition MoECacheModel.hpp:928
int ic
Definition MoECacheModel.hpp:927
std::array< float,(size_t) 5 *NET_BRANCH_DIM *9 > w
Definition MoECacheModel.hpp:929
Definition MoECacheModel.hpp:933
BranchLayer l1
Definition MoECacheModel.hpp:935
BranchLayer l2
Definition MoECacheModel.hpp:936
BranchLayer l0
Definition MoECacheModel.hpp:934
Definition MoECacheModel.hpp:939
std::array< float, 3 > bWdl
Definition MoECacheModel.hpp:970
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *NET_EXPERT_HIDDEN > wHGT
Definition MoECacheModel.hpp:950
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK > wHG
Definition MoECacheModel.hpp:946
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK *kPool2x2Regions > wH16
Definition MoECacheModel.hpp:954
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *64 *NET_EXPERT_HIDDEN > wHT
Definition MoECacheModel.hpp:964
std::array< float, NET_EXPERT_BOTTLENECK > bConv
Definition MoECacheModel.hpp:942
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *NET_MIXER_OUT > wConv
Definition MoECacheModel.hpp:941
std::array< float,(size_t) NET_EXPERT_BOTTLENECK *kPool2x2Regions *NET_EXPERT_HIDDEN > wH16T
Definition MoECacheModel.hpp:958
std::array< float,(size_t) 3 *NET_EXPERT_HIDDEN > wWdl
Definition MoECacheModel.hpp:969
std::array< float, NET_EXPERT_HIDDEN > bH
Definition MoECacheModel.hpp:966
std::array< float,(size_t) NET_EXPERT_HIDDEN *NET_EXPERT_BOTTLENECK *64 > wH
Definition MoECacheModel.hpp:962
Definition FactorizedFeatureExtractor.hpp:47
float branches[12][MAX_BRANCH_PLANES][64]
Definition FactorizedFeatureExtractor.hpp:50
float bypass[12][64]
Definition FactorizedFeatureExtractor.hpp:51
float global[32]
Definition FactorizedFeatureExtractor.hpp:52
Definition MoECacheModel.hpp:1112
double fullExpertCacheUs
Definition MoECacheModel.hpp:1116
double incExpertBottleneckUs
Definition MoECacheModel.hpp:1121
double incGlobalReluUs
Definition MoECacheModel.hpp:1120
double incBranchDeltaUs
Definition MoECacheModel.hpp:1118
double incBypassDeltaUs
Definition MoECacheModel.hpp:1119
double fullBranchForwardUs
Definition MoECacheModel.hpp:1113
double incHiddenDeltaUs
Definition MoECacheModel.hpp:1122
double fullMixerAccumUs
Definition MoECacheModel.hpp:1114
double incExpertCacheRebuildUs
Definition MoECacheModel.hpp:1123
double fullMixerReluUs
Definition MoECacheModel.hpp:1115
Thread-local state for incremental, lightning-fast neural network inference.
Definition MoECacheModel.hpp:1051
static constexpr int nExperts
Definition MoECacheModel.hpp:1063
void run_top2_experts(int e0, int e1, float w0, float w1, float out_wdl[3])
Combines the output of the top 2 routed experts based on their routing weights.
Definition MoECacheModel.hpp:2331
std::unique_ptr< PersistentThreadPool > threadPool
Definition MoECacheModel.hpp:1070
static constexpr int nGlobals
Definition MoECacheModel.hpp:1062
std::array< float,(size_t) kMaxBranchDim *64 > scratchBranchDelta
Definition MoECacheModel.hpp:1102
SharedMoEWeights & mutable_owned_weights()
Definition MoECacheModel.hpp:1136
std::array< std::array< float, kMaxExpertHidden >, kMaxExperts > hiddenAcc
Definition MoECacheModel.hpp:1083
static constexpr int bd
Definition MoECacheModel.hpp:1059
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > exPreAccum
Definition MoECacheModel.hpp:1078
std::shared_ptr< SharedMoEWeights > ownedWeights
Definition MoECacheModel.hpp:1068
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > exReluCache
Definition MoECacheModel.hpp:1080
static void validate_fixed_architecture(const BenchConfig &cfg)
Definition MoECacheModel.hpp:1306
void rebuild_hidden_acc_from_pool2x2(int e, bool max_pool)
Definition MoECacheModel.hpp:1524
void branch_forward_bd16_fast(const Branch &br, const float *HOT_RESTRICT in_planes, float *HOT_RESTRICT out, float *HOT_RESTRICT mid_plane, float *HOT_RESTRICT l1_accum)
Definition MoECacheModel.hpp:1407
void fill_random(unsigned seed)
Definition MoECacheModel.hpp:1342
std::array< float,(size_t) kMaxMixerOut *64 > mixerReluCache
Definition MoECacheModel.hpp:1075
std::array< uint8_t, kMaxExperts > exValid
Definition MoECacheModel.hpp:1081
void update_incremental(const FactorizedInput &cur, const FactorizedInput &prev, const int *dirty_branches, int dirty_count, const int *active_experts, int active_count)
Performs an incremental network update by computing and applying only the differences.
Definition MoECacheModel.hpp:1772
std::array< float, kMaxMixerOut > scratchGproj
Definition MoECacheModel.hpp:1104
std::array< float,(size_t) kMaxBranchDim *64 > scratchT0
Definition MoECacheModel.hpp:1092
void rebuild_hidden_acc_from_gap(int e)
Definition MoECacheModel.hpp:1495
float global_proj_at(int oc, const float *g) const
Definition MoECacheModel.hpp:1565
std::array< float,(size_t) 12 *kMaxBranchDim *64 > scratchDirtyBranches
Definition MoECacheModel.hpp:1100
std::array< float,(size_t) kMaxMixerOut *64 > mixerLinearAccum
Definition MoECacheModel.hpp:1074
int minParallelDirtyHeads
Definition MoECacheModel.hpp:1054
std::array< float,(size_t) 12 *kMaxBranchDim *64 > branchCache
Definition MoECacheModel.hpp:1073
long long total_weights() const
Definition MoECacheModel.hpp:1212
void branch_forward_with_scratch(int b, const float *in_planes, float *out, float *scratch0, float *scratch1)
Definition MoECacheModel.hpp:1433
int nThreads
Definition MoECacheModel.hpp:1053
std::array< std::array< float,(size_t) kMaxBranchDim *64 >, 12 > scratchParallelBranch1
Definition MoECacheModel.hpp:1098
void init(const BenchConfig &cfg)
Definition MoECacheModel.hpp:1336
std::array< std::array< float,(size_t) kMaxExpertBottleneck *64 >, kMaxExperts > scratchParallelExpertDelta
Definition MoECacheModel.hpp:1109
void top2_experts(const float *global, int &e0, int &e1, float &w0, float &w1) const
Definition MoECacheModel.hpp:1574
ExpertPoolMode expertPoolMode
Definition MoECacheModel.hpp:1057
std::array< float,(size_t) kMaxBypass *64 > scratchBypassDelta
Definition MoECacheModel.hpp:1103
void branch_forward(int b, const float *in_planes, float *out)
Definition MoECacheModel.hpp:1471
void parallel_for_indices(int n, int min_parallel_n, Fn &&fn)
Definition MoECacheModel.hpp:1197
void reset_runtime_state()
Definition MoECacheModel.hpp:1143
bool routeSlowGlobals
Definition MoECacheModel.hpp:1058
std::array< float, kMaxExpertHidden > scratchHidden
Definition MoECacheModel.hpp:1110
PhaseProfile profile
Definition MoECacheModel.hpp:1126
std::array< float,(size_t) kMaxExpertBottleneck *64 > scratchFlatDelta
Definition MoECacheModel.hpp:1107
long long experts_total_weights() const
Definition MoECacheModel.hpp:1275
static constexpr int nf
Definition MoECacheModel.hpp:1060
std::array< std::array< float,(size_t) kMaxBranchDim *64 >, 12 > scratchParallelBranch0
Definition MoECacheModel.hpp:1096
static constexpr int nBypass
Definition MoECacheModel.hpp:1061
std::array< float, kMaxGlobals > oldGlobalV
Definition MoECacheModel.hpp:1089
void parallel_for_indices(int n, Fn &&fn)
Definition MoECacheModel.hpp:1208
static constexpr int eh
Definition MoECacheModel.hpp:1065
const SharedMoEWeights * weights
Definition MoECacheModel.hpp:1067
void copy_weights_from(const MoEDoubleAccumulator &src)
Definition MoECacheModel.hpp:1179
void reset_profile()
Definition MoECacheModel.hpp:1194
void run_active_expert(int e, float out_wdl[3])
Computes the final hidden layer and WDL output for a single expert.
Definition MoECacheModel.hpp:2284
long long single_expert_weights() const
Definition MoECacheModel.hpp:1251
std::array< float,(size_t) kMaxMixerOut *64 > scratchDeltaRelu
Definition MoECacheModel.hpp:1105
void rebuild_hidden_acc_from_flat(int e)
Definition MoECacheModel.hpp:1476
int branchConvLayers
Definition MoECacheModel.hpp:1052
std::array< float,(size_t) kMaxBranchDim *64 > scratchNewBranch
Definition MoECacheModel.hpp:1094
static constexpr int ebo
Definition MoECacheModel.hpp:1064
int denseDirtySqThreshold
Definition MoECacheModel.hpp:1056
const SharedMoEWeights & shared_weights() const
Definition MoECacheModel.hpp:1130
int minParallelActiveExperts
Definition MoECacheModel.hpp:1055
std::array< float,(size_t) kMaxBranchDim *64 > scratchT1
Definition MoECacheModel.hpp:1093
void rebuild_expert_cache_from_mixer(int e)
Definition MoECacheModel.hpp:1618
void init(const SharedMoEWeights *shared, const BenchConfig &cfg)
Definition MoECacheModel.hpp:1315
std::array< std::array< float, kMaxExpertBottleneck >, kMaxExperts > exGapCache
Definition MoECacheModel.hpp:1085
bool initialized
Definition MoECacheModel.hpp:1128
std::array< std::array< float,(size_t) kMaxExpertBottleneck *kPool2x2Regions >, kMaxExperts > exPool16Cache
Definition MoECacheModel.hpp:1088
void full_rebuild_accumulators(const FactorizedInput &inp, const int *active_experts, int active_count)
Performs a full forward pass of the model, discarding all cached state.
Definition MoECacheModel.hpp:1665
long long backbone_weights() const
Definition MoECacheModel.hpp:1297
long long runtime_topk_weights(int topk) const
Definition MoECacheModel.hpp:1301
Contains the globally shared, read-only weights for the Factorized MoE network.
Definition MoECacheModel.hpp:986
std::array< float,(size_t) NET_EXPERTS *NET_GLOBALS > gateW
Definition MoECacheModel.hpp:1008
std::array< float, NET_MIXER_OUT > globalB
Definition MoECacheModel.hpp:1005
std::array< float,(size_t) NET_MIXER_OUT *NET_GLOBALS > globalW
Definition MoECacheModel.hpp:1004
std::array< float, NET_MIXER_OUT > mixerB
Definition MoECacheModel.hpp:1001
std::array< float,(size_t) 12 *NET_MIXER_OUT *NET_BRANCH_DIM > mixerWBr
Definition MoECacheModel.hpp:998
static constexpr int bd
Definition MoECacheModel.hpp:987
static constexpr int nf
Definition MoECacheModel.hpp:988
static constexpr int nExperts
Definition MoECacheModel.hpp:991
std::array< float,(size_t) NET_BYPASS *NET_MIXER_OUT > mixerWBp
Definition MoECacheModel.hpp:1000
void init_architecture(int branchConvLayers)
Definition MoECacheModel.hpp:1013
static constexpr int ebo
Definition MoECacheModel.hpp:992
static constexpr int eh
Definition MoECacheModel.hpp:993
std::array< Expert, NET_EXPERTS > experts
Definition MoECacheModel.hpp:1011
std::array< float, NET_EXPERTS > gateB
Definition MoECacheModel.hpp:1009
static constexpr int nGlobals
Definition MoECacheModel.hpp:990
std::array< Branch, 12 > branches
Definition MoECacheModel.hpp:995
static constexpr int nBypass
Definition MoECacheModel.hpp:989