# Tutorial
This section walks through a detailed example: training a two-layer MLP on a synthetic classification task, tuning one learning rate per parameter with HyperDistill and comparing against baseline methods. The full script is in `code/demo_hyperdistill.py`. An interactive notebook covering all five algorithms and their comparison is in `code/demo_methods.ipynb`.
## Task Description
- **Model:** MLP 10 → 32 → 5 (He initialization, matching the \(\sqrt{2/\mathrm{fan~in}}\) scaling in the code below), 517 parameters in total.
- **Data:** 5 classes, Gaussian clusters around fixed centers. The training set contains 20% label noise; the validation set is clean.
- **Hyperparameter \(\lambda\):** a vector of 517 values, one learning rate per model parameter, passed through softplus to guarantee positivity.

This choice of hyperparameter is interesting because \(\lambda\) does not appear directly in \(L_{\mathrm{val}}\), so \(g_{\mathrm{FO}} = 0\) and all signal comes from the second-order term.
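In symbols, writing \(\theta_T(\lambda)\) for the weights after \(T\) inner steps, the hypergradient decomposes into a direct and an indirect term:

\[
\frac{dL_{\mathrm{val}}}{d\lambda}
= \underbrace{\frac{\partial L_{\mathrm{val}}}{\partial \lambda}}_{g_{\mathrm{FO}}}
+ \frac{\partial L_{\mathrm{val}}}{\partial \theta_T}\,\frac{\partial \theta_T}{\partial \lambda}.
\]

Since the learning rates enter only the inner update, the first term vanishes identically here, and everything flows through \(\partial \theta_T / \partial \lambda\).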
## Model Definition
```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, hidden_dim, out_dim):
    k1, k2 = jax.random.split(key)
    return {
        'w1': jax.random.normal(k1, (in_dim, hidden_dim)) * jnp.sqrt(2.0 / in_dim),
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(k2, (hidden_dim, out_dim)) * jnp.sqrt(2.0 / hidden_dim),
        'b2': jnp.zeros(out_dim),
    }

def mlp_forward(params, x):
    h = jax.nn.relu(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']
```
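As a quick sanity check on the 517-parameter figure quoted above, the count is just the weights plus biases of both layers:

```python
in_dim, hidden_dim, out_dim = 10, 32, 5
n_params = (in_dim * hidden_dim + hidden_dim      # w1 + b1
            + hidden_dim * out_dim + out_dim)     # w2 + b2
print(n_params)  # 517
```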
## Loss Functions and Inner Step
```python
def cross_entropy_loss(params, batch):
    x, y = batch
    logits = mlp_forward(params, x)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(log_probs * y, axis=-1))

# Bilevel interface wrapper: (params, hyperparams, batch) -> scalar
def bilevel_loss(params, hyperparams, batch):
    return cross_entropy_loss(params, batch)

def make_update_fn(loss_fn):
    def update_fn(w, lr_params, batch):
        grads = jax.grad(loss_fn)(w, batch)
        return jax.tree.map(
            lambda w_i, lr_i, g_i: w_i - jax.nn.softplus(lr_i) * g_i,
            w, lr_params, grads,
        )
    return update_fn
```
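The softplus reparameterization is what keeps every per-parameter learning rate strictly positive no matter what raw value the outer optimizer assigns to \(\lambda\). A minimal NumPy sketch of the function, independent of the library:

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + exp(x)): max(x, 0) + log1p(exp(-|x|))
    return np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))

raw = np.array([-5.0, 0.0, 5.0])
lrs = softplus(raw)
print(lrs)  # all strictly positive: ~exp(x) for very negative x, ~x for large x
```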
## Running HyperDistill
```python
from gradhpo.algorithms.online import OnlineHypergradientOptimizer

update_fn = make_update_fn(cross_entropy_loss)
opt = OnlineHypergradientOptimizer(
    update_fn=update_fn,
    gamma=0.99,
    estimation_period=10,
    T=20,
)

# w_init, lam_init, evaluate, train_iter, val_iter are defined earlier in the script
state = opt.init(w_init, lam_init)

losses = []
def callback(episode, state):
    loss, _ = evaluate(state.params, X_val, Y_val)
    losses.append(loss)

state = opt.run(
    state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_reptile=1.0,
    lr_hyper=3e-3,
    callback=callback,
)
```
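The script supplies its own data pipeline; the snippet below is a hypothetical sketch of what `train_iter`/`val_iter` could look like for the task described above (Gaussian clusters, 20% training label noise, one-hot targets). The dataset sizes, batch size, and cluster spread here are illustrative, not the script's actual values:

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, dim = 5, 10
centers = rng.normal(size=(n_classes, dim))   # fixed cluster centers

def make_split(n, noise=0.0):
    y = rng.integers(n_classes, size=n)
    x = centers[y] + 0.5 * rng.normal(size=(n, dim))
    flip = rng.random(n) < noise              # corrupt a fraction of labels
    y = np.where(flip, rng.integers(n_classes, size=n), y)
    onehot = np.eye(n_classes)[y]
    return x.astype(np.float32), onehot.astype(np.float32)

X_tr, Y_tr = make_split(512, noise=0.2)       # 20% label noise (training)
X_val, Y_val = make_split(256, noise=0.0)     # clean validation

def train_iter():
    idx = rng.integers(len(X_tr), size=64)    # random minibatch of 64
    return X_tr[idx], Y_tr[idx]

def val_iter():
    return X_val, Y_val
```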
## Running T1T2 and Greedy
```python
import optax
from gradhpo.algorithms.t1t2 import T1T2Optimizer
from gradhpo.algorithms.greedy import GreedyOptimizer

# T1T2 uses the same update_fn
t1t2_opt = T1T2Optimizer(update_fn=update_fn, gamma=0.9, T=20)
t1t2_state = t1t2_opt.init(w_init, lam_init)
t1t2_state = t1t2_opt.run(
    t1t2_state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_hyper=3e-3,
)

# GreedyOptimizer takes Optax optimizers
greedy_opt = GreedyOptimizer(
    inner_optimizer=optax.sgd(0.01),
    outer_optimizer=optax.adam(3e-3),
    unroll_steps=5,
    gamma=0.9,
)
greedy_state = greedy_opt.init(w_init, lam_init)
greedy_state = greedy_opt.run(
    greedy_state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
)
```
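To see what Greedy's truncated unrolling computes, here is a toy 1-D sketch (not the library's internals): differentiate the validation loss through `unroll_steps` inner SGD steps on a quadratic, where the hypergradient with respect to the learning rate is available in closed form and can be checked against a finite difference.

```python
def unrolled_val_loss(lr, w0=2.0, K=5):
    w = w0
    for _ in range(K):        # K truncated inner SGD steps on L_train = 0.5 * w^2
        w = w - lr * w        # gradient of L_train is w
    return 0.5 * w * w        # validation loss after the unroll

lr, eps, K, w0 = 0.1, 1e-6, 5, 2.0
fd = (unrolled_val_loss(lr + eps) - unrolled_val_loss(lr - eps)) / (2 * eps)

# Closed form: w_K = (1 - lr)^K * w0, so dL_val/dlr = -K * w0^2 * (1 - lr)^(2K - 1)
analytic = -K * w0 ** 2 * (1 - lr) ** (2 * K - 1)
print(fd, analytic)  # the two agree; increasing lr here decreases L_val
```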
## Comparison with Baselines
The baselines use the same interface:
```python
from gradhpo.algorithms.baselines import OneStepOptimizer, FOOptimizer

onestep_opt = OneStepOptimizer(update_fn=update_fn)
onestep_state = onestep_opt.init(w_init, lam_init)
onestep_state = onestep_opt.run(
    onestep_state, M=60, T=20,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_reptile=1.0,
    lr_hyper=3e-3,
)

fo_opt = FOOptimizer(update_fn=update_fn)
fo_state = fo_opt.init(w_init, lam_init)
fo_state = fo_opt.run(
    fo_state, M=60, T=20,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_hyper=3e-3,
)
```
## Expected Results
With the default settings (`M=60`, `T=20`, `gamma=0.99`):
| Method | last-5 avg loss | best loss |
|---|---|---|
| Fixed LR | ~0.186 | ~0.175 |
| FO | ~0.183 | ~0.172 |
| One-Step | ~0.171 | ~0.160 |
| T1T2 | ~0.168 | ~0.155 |
| HyperDistill | ~0.162 | ~0.148 |
| Greedy | ~0.165 | ~0.152 |
HyperDistill consistently reaches the lowest validation loss because it incorporates information from the entire training trajectory via EMA distillation, while One-Step only sees the local effect of the last step. T1T2 and Greedy fall in between.
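The exponential moving average that underpins this trajectory aggregation weights recent steps by \((1-\gamma)\) and decays older ones geometrically; a generic sketch of the recursion (illustrative only, not the library's internal state):

```python
gamma = 0.99
ema = 0.0
signals = [1.0, 0.5, 2.0, 1.5]   # stand-in for per-step hypergradient signals
for s in signals:
    ema = gamma * ema + (1 - gamma) * s   # old information decays by gamma each step
```

With `gamma` close to 1, the average changes slowly and retains information from far back in the trajectory; smaller values make it track only the most recent steps.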
## Visualization
```bash
python code/demo_hyperdistill.py --plot
```
This saves a convergence plot to `figures/hyperdistill_poc.png`.
The plot shows that HyperDistill and One-Step converge faster than Fixed LR,
with HyperDistill reaching a lower validation loss.
For an interactive comparison of all five algorithms, open the notebook:
```bash
jupyter notebook code/demo_methods.ipynb
```
The notebook contains six sections: Demo 1 (HyperDistill), Demo 2 (One-Step), Demo 3 (FO), Demo 4 (T1T2), Demo 5 (Greedy), and Demo 6 — a comparison of all methods on a single plot.