Quick Start

This section shows a minimal working example: training a linear model with regularization coefficient tuning via HyperDistill.

Data Preparation

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Synthetic data: 200 training, 100 validation samples
X_train = jax.random.normal(k1, (200, 10))
y_train = jnp.sign(X_train @ jnp.ones(10))

X_val = jax.random.normal(k2, (100, 10))
y_val = jnp.sign(X_val @ jnp.ones(10))

Model Definition

The loss function takes three arguments (params, hyperparams, batch) — this is the unified interface for all library algorithms.

def loss_fn(params, hyperparams, batch):
    """MSE with L2 regularization, where lambda = softplus(hyperparams)."""
    X, y = batch
    pred = X @ params['w']
    mse = jnp.mean((pred - y) ** 2)
    reg = jax.nn.softplus(hyperparams['log_lam']) * jnp.sum(params['w'] ** 2)
    return mse + reg

Initialization

from gradhpo import OnlineHypergradientOptimizer

w_init = {'w': jnp.zeros(10)}
lam_init = {'log_lam': jnp.array(0.0)}

def update_fn(w, lam, batch):
    grads = jax.grad(loss_fn)(w, lam, batch)
    return jax.tree.map(lambda p, g: p - 0.01 * g, w, grads)

opt = OnlineHypergradientOptimizer(
    update_fn=update_fn,
    gamma=0.99,
    estimation_period=10,
    T=20,
)

state = opt.init(w_init, lam_init)

Training

def get_train():
    return (X_train, y_train)

def get_val():
    return (X_val, y_val)

state = opt.run(
    state, M=30,
    get_train_batch=get_train,
    get_val_batch=get_val,
    train_loss_fn=loss_fn,
    val_loss_fn=loss_fn,
    lr_hyper=1e-3,
)

print(f"lambda = {jax.nn.softplus(state.hyperparams['log_lam']):.4f}")

Comparing Multiple Methods

The same interface works for all algorithms. For FOOptimizer and OneStepOptimizer, the run() method requires an explicit T argument (number of inner steps). GreedyOptimizer takes Optax optimizers instead of update_fn and does not require lr_hyper in step() / run().

import optax
from gradhpo import (
    OnlineHypergradientOptimizer,
    T1T2Optimizer,
    GreedyOptimizer,
    FOOptimizer,
    OneStepOptimizer,
)

# Algorithms based on update_fn
methods_update_fn = {
    'FO':           FOOptimizer(update_fn=update_fn),
    'One-Step':     OneStepOptimizer(update_fn=update_fn),
    'HyperDistill': OnlineHypergradientOptimizer(
                        update_fn=update_fn, gamma=0.99,
                        estimation_period=10, T=20),
    'T1T2':         T1T2Optimizer(update_fn=update_fn, gamma=0.9, T=20),
}

for name, opt in methods_update_fn.items():
    st = opt.init(w_init, lam_init)
    # FO and One-Step require explicit T
    extra = {'T': 20} if name in ('FO', 'One-Step') else {}
    st = opt.run(
        st, M=30,
        get_train_batch=get_train,
        get_val_batch=get_val,
        train_loss_fn=loss_fn,
        val_loss_fn=loss_fn,
        lr_hyper=1e-3,
        **extra,
    )
    lam = jax.nn.softplus(st.hyperparams['log_lam'])
    print(f"{name}: lambda = {lam:.4f}")

# GreedyOptimizer uses Optax optimizers
greedy = GreedyOptimizer(
    inner_optimizer=optax.sgd(0.01),
    outer_optimizer=optax.adam(1e-3),
    unroll_steps=5,
    gamma=0.9,
)
gs = greedy.init(w_init, lam_init)
gs = greedy.run(
    gs, M=30,
    get_train_batch=get_train,
    get_val_batch=get_val,
    train_loss_fn=loss_fn,
    val_loss_fn=loss_fn,
)
lam = jax.nn.softplus(gs.hyperparams['log_lam'])
print(f"Greedy: lambda = {lam:.4f}")

A detailed example with result visualization is provided in Tutorial. The full API reference is in API Reference.