Tutorial

This section walks through a detailed example: training a two-layer MLP on a synthetic classification task with per-parameter learning rate tuning via HyperDistill and baselines.

The full script is in code/demo_hyperdistill.py. An interactive notebook covering all five algorithms and their comparison is in code/demo_methods.ipynb.

Task Description

  • Model: MLP 10 → 32 → 5 (Xavier initialization), 517 parameters.

  • Data: 5 classes, Gaussian clusters around fixed centers. The training set contains 20% label noise; the validation set is clean.

  • Hyperparameter \(\lambda\): a vector of 517 values — one learning rate per model parameter, passed through softplus to guarantee positivity.

\[w_{t+1} = w_t - \mathrm{softplus}(\lambda) \odot \nabla L_{\mathrm{train}}(w_t).\]

This choice of hyperparameter is interesting because \(\lambda\) does not appear directly in \(L_{\mathrm{val}}\), so \(g_{\mathrm{FO}} = 0\) and all signal comes from the second-order term.

Model Definition

import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, hidden_dim, out_dim):
    k1, k2 = jax.random.split(key)
    return {
        'w1': jax.random.normal(k1, (in_dim, hidden_dim)) * jnp.sqrt(2.0 / in_dim),
        'b1': jnp.zeros(hidden_dim),
        'w2': jax.random.normal(k2, (hidden_dim, out_dim)) * jnp.sqrt(2.0 / hidden_dim),
        'b2': jnp.zeros(out_dim),
    }

def mlp_forward(params, x):
    h = jax.nn.relu(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']

Loss Functions and Inner Step

def cross_entropy_loss(params, batch):
    x, y = batch
    logits = mlp_forward(params, x)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(log_probs * y, axis=-1))

# Bilevel interface wrapper: (params, hyperparams, batch) -> scalar
def bilevel_loss(params, hyperparams, batch):
    return cross_entropy_loss(params, batch)

def make_update_fn(loss_fn):
    def update_fn(w, lr_params, batch):
        grads = jax.grad(loss_fn)(w, batch)
        return jax.tree.map(
            lambda w_i, lr_i, g_i: w_i - jax.nn.softplus(lr_i) * g_i,
            w, lr_params, grads,
        )
    return update_fn

Running HyperDistill

from gradhpo.algorithms.online import OnlineHypergradientOptimizer

update_fn = make_update_fn(cross_entropy_loss)

opt = OnlineHypergradientOptimizer(
    update_fn=update_fn,
    gamma=0.99,
    estimation_period=10,
    T=20,
)

state = opt.init(w_init, lam_init)

losses = []

def callback(episode, state):
    loss, _ = evaluate(state.params, X_val, Y_val)
    losses.append(loss)

state = opt.run(
    state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_reptile=1.0,
    lr_hyper=3e-3,
    callback=callback,
)

Running T1T2 and Greedy

import optax
from gradhpo.algorithms.t1t2 import T1T2Optimizer
from gradhpo.algorithms.greedy import GreedyOptimizer

# T1T2 uses the same update_fn
t1t2_opt = T1T2Optimizer(update_fn=update_fn, gamma=0.9, T=20)
t1t2_state = t1t2_opt.init(w_init, lam_init)
t1t2_state = t1t2_opt.run(
    t1t2_state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_hyper=3e-3,
)

# GreedyOptimizer takes Optax optimizers
greedy_opt = GreedyOptimizer(
    inner_optimizer=optax.sgd(0.01),
    outer_optimizer=optax.adam(3e-3),
    unroll_steps=5,
    gamma=0.9,
)
greedy_state = greedy_opt.init(w_init, lam_init)
greedy_state = greedy_opt.run(
    greedy_state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
)

Comparison with Baselines

The baselines use the same interface:

from gradhpo.algorithms.baselines import OneStepOptimizer, FOOptimizer

onestep_opt = OneStepOptimizer(update_fn=update_fn)
onestep_state = onestep_opt.init(w_init, lam_init)
onestep_state = onestep_opt.run(
    onestep_state, M=60, T=20,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_reptile=1.0,
    lr_hyper=3e-3,
)

fo_opt = FOOptimizer(update_fn=update_fn)
fo_state = fo_opt.init(w_init, lam_init)
fo_state = fo_opt.run(
    fo_state, M=60, T=20,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=bilevel_loss,
    val_loss_fn=bilevel_loss,
    lr_hyper=3e-3,
)

Expected Results

With default parameters (M=60, T=20, gamma=0.99):

Method

last-5 avg loss

best loss

Fixed LR

~0.186

~0.175

FO

~0.183

~0.172

One-Step

~0.171

~0.160

T1T2

~0.168

~0.155

HyperDistill

~0.162

~0.148

Greedy

~0.165

~0.152

HyperDistill consistently achieves the best result because it incorporates information from the entire training trajectory via EMA distillation, while One-Step only sees the local effect of the last step. T1T2 and Greedy occupy an intermediate position.

Visualization

python code/demo_hyperdistill.py --plot

This saves a convergence plot to figures/hyperdistill_poc.png. The plot shows that HyperDistill and One-Step converge faster than Fixed LR, with HyperDistill reaching a lower validation loss.

For an interactive comparison of all five algorithms, open the notebook:

jupyter notebook code/demo_methods.ipynb

The notebook contains six sections: Demo 1 (HyperDistill), Demo 2 (One-Step), Demo 3 (FO), Demo 4 (T1T2), Demo 5 (Greedy), and Demo 6 — a comparison of all methods on a single plot.