Architecture
The SHC Transformer architecture extends the standard transformer with multi-stream residual connections using sparse orthogonal routing.
Overview
┌─────────────────────────────────────────────────────────┐
│ SHC Transformer │
├─────────────────────────────────────────────────────────┤
│ │
│ Input: token_ids (batch, seq_len) │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ Token Embedding │ vocab_size → hidden_dim │
│ └─────────────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ N × SHC Block │ With orthogonal routing │
│ └─────────────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ RMS Norm │ │
│ └─────────────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ LM Head │ hidden_dim → vocab_size │
│ └─────────────────────┘ │
│ ↓ │
│ Output: logits (batch, seq_len, vocab_size) │
│ │
└─────────────────────────────────────────────────────────┘
SHC Block
Each SHC Block implements Algorithm 1 from the paper:
SHC Block Forward Pass
──────────────────────
Input: x ∈ ℝ^d, layer index l
1. n_eff ← AdaptiveRank(x, l) # Adaptive stream expansion
2. IF n_eff > 1:
a. x̄ ← StreamExpand(x, n_eff) # Expand to n streams
b. α ← softmax(W_α · x̄) # Compute mixing weights
c. H^res ← Σ αᵢ · Q(Aᵢ) # Cayley routing matrix
d. x̄_out ← H^res · x̄ + H^post · f(H^pre · x̄)
↑ ↑ ↑
residual output input
routing routing routing
e. x_out ← Compress(x̄_out, r=1) # Factorized cache
3. ELSE:
x_out ← x + f(x) # Standard residual
Return: x_out
Model Configurations
Size |
Hidden Dim |
Layers |
Heads |
FFN Dim |
Parameters |
|---|---|---|---|---|---|
500M |
1024 |
24 |
16 |
4096 |
~500M |
1B |
2048 |
24 |
16 |
8192 |
~1B |
3B |
2560 |
32 |
32 |
10240 |
~3B |
7B |
4096 |
32 |
32 |
11008 |
~7B |
from shc.models import get_config, SHCTransformer
# Load predefined configuration
config = get_config('3b')
model = SHCTransformer(config)
Core Components
CayleyTransform
Generates orthogonal matrices with exactly \(\rho = 1\):
from shc.layers import CayleyTransform
cayley = CayleyTransform(n=4, init_scale=0.01)
Q = cayley() # 4×4 orthogonal matrix
SparseOrthogonalMixture
Input-dependent mixture of \(k\) orthogonal matrices:
from shc.layers import SparseOrthogonalMixture
routing = SparseOrthogonalMixture(
n=4, # Number of streams
k=2, # Number of orthogonal matrices
hidden_dim=768 # Dimension for computing mixing weights
)
H_res = routing(x) # (batch, n, n) routing matrix
FactorizedKVCache
Low-rank compression for efficient caching:
from shc.layers import FactorizedKVCache
cache = FactorizedKVCache(
n=4, # Number of streams
d=768, # Hidden dimension
r=1 # Factorization rank
)
AdaptiveRankSelector
Layer-wise and input-dependent effective rank:
from shc.layers import AdaptiveRankSelector
selector = AdaptiveRankSelector(n=4, hidden_dim=768)
n_eff = selector(x) # Effective number of streams
Multi-Head Attention
Standard multi-head attention with RoPE positional encoding:
from shc.blocks import MultiHeadAttention
attention = MultiHeadAttention(
hidden_dim=768,
n_heads=12,
max_seq_len=4096,
use_rope=True
)
Feed-Forward Network
SwiGLU activation with learnable gating:
from shc.blocks import FeedForward
ffn = FeedForward(
hidden_dim=768,
ffn_dim=3072 # Typically 4× hidden_dim
)
Generation
Autoregressive generation with KV caching:
output = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.7,
top_k=50,
top_p=0.9,
do_sample=True
)
SSM Distillation
For O(L) inference, distill into a State Space Model:
from shc.models import SSMStudent
from shc.training import DistillationTrainer
# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)
# Distill
trainer = DistillationTrainer(teacher, student, config, data)
trainer.train()
# Student generates without KV cache
output = student.generate(input_ids, max_new_tokens=100)