Training Guide

This guide covers training SHC models from scratch with multi-GPU support.

Basic Training

Single GPU

from shc.models import SHCTransformer, get_config
from shc.training import SHCTrainer, TrainingArgs

# Create model
model = SHCTransformer(get_config('500m'))

# Configure training
args = TrainingArgs(
    output_dir='./output',
    learning_rate=3e-4,
    max_steps=100000,
    batch_size=32,
)

# Create trainer
trainer = SHCTrainer(
    model=model,
    args=args,
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
)

# Train
trainer.train()

Multi-GPU with DDP

torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

FSDP for Large Models

For 7B+ parameters, use Fully Sharded Data Parallel:

torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16

Training Configuration

from shc.training import TrainingArgs

args = TrainingArgs(
    # Output
    output_dir='./output',
    run_name='shc_3b_run1',
    
    # Optimization
    learning_rate=3e-4,
    weight_decay=0.1,
    max_steps=100000,
    warmup_steps=2000,
    
    # Batch Configuration
    batch_size=32,
    gradient_accumulation_steps=8,
    
    # Precision
    mixed_precision='bf16',  # 'fp32', 'fp16', 'bf16'
    
    # Distributed
    use_ddp=True,
    use_fsdp=False,
    
    # Logging
    log_steps=100,
    eval_steps=1000,
    save_steps=5000,
    
    # Reproducibility
    seed=42,
)

Hyperparameters by Scale

Parameter

500M

3B

7B

Learning Rate

3e-4

3e-4

1.5e-4

Batch Size

1024

1024

512

Warmup Steps

2000

2000

4000

Weight Decay

0.1

0.1

0.1

Data Loading

from shc.data import TokenizedDataset, create_dataloader

# Create dataset
dataset = TokenizedDataset(
    data_path='path/to/tokenized_data',
    max_seq_len=2048,
)

# Create distributed dataloader
dataloader = create_dataloader(
    dataset,
    batch_size=32,
    num_workers=8,
    shuffle=True,
)

Gradient Checkpointing

For memory-efficient training:

config = get_config('7b')
config.use_gradient_checkpointing = True
model = SHCTransformer(config)

Mixed Precision Training

from shc.training import setup_distributed

# BF16 is preferred for stability
args = TrainingArgs(mixed_precision='bf16')

# FP16 requires loss scaling
args = TrainingArgs(mixed_precision='fp16')

Checkpointing

# Save checkpoint
trainer.save_checkpoint('checkpoint_50k')

# Load and resume
trainer.load_checkpoint('checkpoint_50k')
trainer.train()

Monitoring

TensorBoard

args = TrainingArgs(
    use_tensorboard=True,
    log_steps=100,
)
tensorboard --logdir=./output

Weights & Biases

args = TrainingArgs(
    use_wandb=True,
    run_name='shc_experiment',
)

SSM Distillation

For O(L) inference, distill trained model:

from shc.training import DistillationTrainer, DistillationConfig

# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')
teacher.eval()

# Create student
from shc.models import SSMStudent
student = SSMStudent.from_teacher_config(teacher.config)

# Configure distillation
config = DistillationConfig(
    max_steps=10000,
    learning_rate=1e-4,
    temperature=2.0,
    alpha_ce=0.5,
    alpha_kd=0.5,
)

# Distill
trainer = DistillationTrainer(teacher, student, config, train_loader)
trainer.train()

# Save student
student.save_pretrained('path/to/student')

Best Practices

  1. Start Small: Verify training works on 500M before scaling

  2. Monitor Spectral Norms: Check get_routing_stats() periodically

  3. Use BF16: More stable than FP16 for SHC training

  4. Gradient Clipping: Default max_grad_norm=1.0 is recommended