io-chess
UCI chess engine
train Namespace Reference

Classes

class  ChunkedRandomSampler
class  ExpertChunkedSampler
class  ConcatChunkedRandomSampler
class  ConcatExpertChunkedSampler
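The class names suggest samplers that shuffle at chunk granularity instead of per sample. The following is a minimal sketch of that general technique, assuming the dataset's indices are split into contiguous chunks of `chunk_size`. The function `chunked_random_order` and its parameters are hypothetical and are not the interface of `ChunkedRandomSampler`.

```python
import random
from typing import List


def chunked_random_order(n_items: int, chunk_size: int, seed: int = 0) -> List[int]:
    """Hypothetical sketch: shuffle the chunk order, then shuffle indices inside each chunk.

    Keeping each chunk's indices together preserves locality when the
    underlying data is stored in contiguous blocks on disk.
    """
    rng = random.Random(seed)
    # Start offsets of contiguous chunks: 0, chunk_size, 2*chunk_size, ...
    starts = list(range(0, n_items, chunk_size))
    rng.shuffle(starts)
    order: List[int] = []
    for start in starts:
        chunk = list(range(start, min(start + chunk_size, n_items)))
        rng.shuffle(chunk)
        order.extend(chunk)
    return order
```

An ordering like this could be wrapped in a `torch.utils.data.Sampler` subclass and passed to `build_loader` through its `custom_sampler` parameter; how the real samplers do it is not documented here.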

Functions

 print_header (text)
 format_time (seconds)
 get_device ()
 init_wandb (args, phase_name)
 build_loader (dataset, batch_size, shuffle, workers, custom_sampler=None)
 parse_data_roots (args)
 resolve_data_splits (data_roots)
 build_phase14_datasets (data_roots, n_globals)
 phase_eta_min (phase)
 build_scheduler (optimizer, total_steps, warmup_steps, eta_min)
 split_steps_for_epoch (total_steps, val_splits)
 arch_string (args)
 phase1_export_state_dict (model)
 save_model_checkpoint (path, model, phase, epoch, val_terms, args, extra=None, state_dict_override=None)
 empty_loss_terms ()
 add_loss_terms (sums, total_t, wdl_t, n)
 finalize_loss_terms (sums, total_samples)
 format_loss_terms (tag, terms, color)
 to_planes_list (branches)
 top2_sparse_weights (base_weights)
 unpack_common_batch (batch, device)
 forward_backbone (model, planes_list, bypass, global_feats)
 train_phase1_steps (model, data_iter, num_steps, optimizer, scheduler, device)
 validate_phase1 (model, loader, device)
 copy_expert0_to_all (model)
 run_phase1 (model, args, wandb_run=None)
 train_expert_steps (model, expert_idx, data_iter, num_steps, optimizer, scheduler, device)
 validate_expert (model, expert_idx, loader, device)
 load_checkpoint_into_model (model, checkpoint_path, device)
 run_phase2 (model, args, wandb_run=None)
 train_phase4_steps (model, data_iter, num_steps, optimizer, scheduler, device)
 validate_phase4 (model, loader, device)
 run_phase4 (model, args, wandb_run=None)
 main ()

Variables

bool WANDB_AVAILABLE = True
 C = Colors

Detailed Description

train.py: training script for model.py (factorized MoE).

Three-phase training pipeline (the phases are numbered 1, 2 and 4; the numbering skips 3):
  Phase 1: Base training with teacher routing weights
  Phase 2: Expert specialization on expert-specific datasets
  Phase 4: Joint fine-tuning with top-2 routing

This script is intentionally styled after train_light_moe.py while using
factorized packed inputs from dataset.py.

Function Documentation

◆ add_loss_terms()

add_loss_terms ( sums,
total_t,
wdl_t,
n )

◆ arch_string()

arch_string ( args)

◆ build_loader()

build_loader ( dataset,
batch_size,
shuffle,
workers,
custom_sampler = None )

◆ build_phase14_datasets()

build_phase14_datasets ( data_roots,
n_globals )

◆ build_scheduler()

build_scheduler ( optimizer,
total_steps,
warmup_steps,
eta_min )
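The parameter names `warmup_steps` and `eta_min` suggest a warmup period followed by a decay to a floor learning rate. A minimal sketch, assuming linear warmup followed by cosine annealing, written as a plain function of the step index; `lr_at_step` and `base_lr` are hypothetical names, and the schedule actually built by `build_scheduler` may differ.

```python
import math


def lr_at_step(step: int, base_lr: float, total_steps: int,
               warmup_steps: int, eta_min: float) -> float:
    """Hypothetical sketch: linear warmup, then cosine decay from base_lr to eta_min."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return eta_min + (base_lr - eta_min) * cosine
```

In PyTorch, a function like this can be attached to an optimizer with `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at_step(s, base_lr, total, warmup, eta_min) / base_lr)`, since `LambdaLR` expects a multiplicative factor rather than an absolute rate.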

◆ copy_expert0_to_all()

copy_expert0_to_all ( model)

◆ empty_loss_terms()

empty_loss_terms ( )

◆ finalize_loss_terms()

finalize_loss_terms ( sums,
total_samples )
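`empty_loss_terms`, `add_loss_terms` and `finalize_loss_terms` read as a running-average helper set: start empty, add each batch's losses weighted by its size `n`, then divide by the total sample count. The sketch below shows that pattern under the assumption that only a total loss and a WDL loss are tracked; the function names and dictionary keys are hypothetical, and plain floats stand in for the tensors the real code probably passes (`float()` also accepts 0-dimensional tensors).

```python
from typing import Dict


def empty_terms() -> Dict[str, float]:
    """Hypothetical sketch: zeroed running sums for each loss term."""
    return {"total": 0.0, "wdl": 0.0}


def add_terms(sums: Dict[str, float], total_t: float, wdl_t: float, n: int) -> None:
    """Accumulate batch-mean losses weighted by the batch size n."""
    sums["total"] += float(total_t) * n
    sums["wdl"] += float(wdl_t) * n


def finalize_terms(sums: Dict[str, float], total_samples: int) -> Dict[str, float]:
    """Turn the weighted sums into per-sample averages."""
    denom = max(1, total_samples)  # guard against an empty validation set
    return {key: value / denom for key, value in sums.items()}
```

Weighting by `n` keeps the final average exact even when the last batch is smaller than the others.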

◆ format_loss_terms()

format_loss_terms ( tag,
terms,
color )

◆ format_time()

format_time ( seconds)

◆ forward_backbone()

forward_backbone ( model,
planes_list,
bypass,
global_feats )

◆ get_device()

get_device ( )

◆ init_wandb()

init_wandb ( args,
phase_name )

◆ load_checkpoint_into_model()

load_checkpoint_into_model ( model,
checkpoint_path,
device )

◆ main()

main ( )

◆ parse_data_roots()

parse_data_roots ( args)

◆ phase1_export_state_dict()

phase1_export_state_dict ( model)
Builds a checkpoint state dict in which the weights of expert 0 are copied to every expert.
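A minimal sketch of that idea, assuming expert parameters live under state-dict keys of the form `experts.<index>.<name>`. The key prefix, the `num_experts` argument and the function name are hypothetical; the real model's parameter naming may differ.

```python
import copy
import re
from typing import Any, Dict

# Hypothetical key layout: "experts.<index>.<parameter path>"
_EXPERT_KEY = re.compile(r"^experts\.(\d+)\.(.+)$")


def export_with_expert0_everywhere(state: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
    """Return a new state dict in which every expert holds expert 0's weights."""
    out: Dict[str, Any] = {}
    expert0: Dict[str, Any] = {}
    for key, value in state.items():
        match = _EXPERT_KEY.match(key)
        if match is None:
            out[key] = value  # shared (non-expert) parameter, kept as-is
        elif match.group(1) == "0":
            expert0[match.group(2)] = value
        # Entries for experts 1..N-1 are dropped and rebuilt below.
    for idx in range(num_experts):
        for name, value in expert0.items():
            # Deep-copy so the experts do not alias one object
            # (for tensors, value.clone() would be the usual choice).
            out[f"experts.{idx}.{name}"] = copy.deepcopy(value)
    return out
```

The input dict is left unchanged, so the live model keeps training on its own weights while the exported checkpoint starts every expert from the same Phase 1 solution.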

◆ phase_eta_min()

phase_eta_min ( phase)

◆ print_header()

print_header ( text)

◆ resolve_data_splits()

resolve_data_splits ( data_roots)

◆ run_phase1()

run_phase1 ( model,
args,
wandb_run = None )

◆ run_phase2()

run_phase2 ( model,
args,
wandb_run = None )

◆ run_phase4()

run_phase4 ( model,
args,
wandb_run = None )

◆ save_model_checkpoint()

save_model_checkpoint ( path,
model,
phase,
epoch,
val_terms,
args,
extra = None,
state_dict_override = None )

◆ split_steps_for_epoch()

split_steps_for_epoch ( total_steps,
val_splits )
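The signature suggests that an epoch's `total_steps` is divided into `val_splits` segments, with validation run after each one. A sketch under that assumption; the function name `split_steps` and the way the remainder is distributed are guesses, not the documented behavior of `split_steps_for_epoch`.

```python
from typing import List


def split_steps(total_steps: int, val_splits: int) -> List[int]:
    """Hypothetical sketch: split total_steps into val_splits near-equal segments."""
    # Never create more segments than steps, and always at least one.
    splits = max(1, min(val_splits, total_steps))
    base, extra = divmod(total_steps, splits)
    # The first `extra` segments get one additional step so the sizes sum exactly.
    return [base + (1 if i < extra else 0) for i in range(splits)]
```

Each segment length would then be a natural `num_steps` argument for `train_phase1_steps`, `train_expert_steps` or `train_phase4_steps`, followed by the matching validation call.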

◆ to_planes_list()

to_planes_list ( branches)

◆ top2_sparse_weights()

top2_sparse_weights ( base_weights)
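Phase 4 is described as joint fine-tuning with top-2 routing. A minimal sketch of top-2 sparsification, assuming `base_weights` holds one non-negative routing weight per expert: keep the two largest weights, zero the rest, and renormalize so the kept weights sum to 1. Plain lists stand in for tensors, the name `top2_sparse` is hypothetical, and whether the real function renormalizes is an assumption.

```python
from typing import List


def top2_sparse(weights: List[float]) -> List[float]:
    """Hypothetical sketch: zero all but the two largest weights, then renormalize."""
    # Indices of the two largest weights (ties broken by lower index).
    ranked = sorted(range(len(weights)), key=lambda i: (-weights[i], i))
    keep = set(ranked[:2])
    kept = [w if i in keep else 0.0 for i, w in enumerate(weights)]
    total = sum(kept)
    # Leave an all-zero row untouched instead of dividing by zero.
    return [w / total for w in kept] if total > 0 else kept
```

On batched tensors the same selection is usually done with `torch.topk` along the expert dimension followed by `scatter` into a zero tensor.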

◆ train_expert_steps()

train_expert_steps ( model,
expert_idx,
data_iter,
num_steps,
optimizer,
scheduler,
device )

◆ train_phase1_steps()

train_phase1_steps ( model,
data_iter,
num_steps,
optimizer,
scheduler,
device )

◆ train_phase4_steps()

train_phase4_steps ( model,
data_iter,
num_steps,
optimizer,
scheduler,
device )

◆ unpack_common_batch()

unpack_common_batch ( batch,
device )

◆ validate_expert()

validate_expert ( model,
expert_idx,
loader,
device )

◆ validate_phase1()

validate_phase1 ( model,
loader,
device )

◆ validate_phase4()

validate_phase4 ( model,
loader,
device )

Variable Documentation

◆ C

train.C = Colors

◆ WANDB_AVAILABLE

bool WANDB_AVAILABLE = True