|
| | model.model = ChessNetFactorizedMoE(expert_bottleneck=16, mixer_out=512) |
| list | model.planes_list = [torch.randn(1, in_ch, 8, 8) for in_ch in ChessNetFactorizedMoE.PLANES_PER_TYPE] |
| | model.bypass = torch.randn(1, 12, 8, 8) |
| | model.global_v = torch.randn(1, 21) |
| | model.wdl = model(planes_list, bypass, global_v) |
| | model.total_params = sum(p.numel() for p in model.parameters()) |
| | model.branches_params = sum(p.numel() for n, p in model.named_parameters() if "branches" in n) |
| | model.stem_global_params = sum(p.numel() for n, p in model.named_parameters() if "stem_global" in n) |
| | model.pointwise_mixer_params = sum(p.numel() for n, p in model.named_parameters() if "pointwise_mixer" in n) |
| | model.backbone_params = sum(p.numel() for n, p in model.named_parameters() if "branches" in n or "pointwise_mixer" in n or "stem_global" in n) |
| | model.expert_params = sum(p.numel() for n, p in model.named_parameters() if "experts" in n) |