GradHpO: Gradient-Based Hyperparameter Optimization
Short-horizon gradient-based hyperparameter optimization in JAX
Welcome to GradHpO
GradHpO is a JAX library for gradient-based hyperparameter optimization
in bilevel optimization problems. Five algorithms are implemented under
a unified ``BilevelOptimizer`` interface:
HyperDistill — online optimization with EMA weight distillation (Lee et al., ICLR 2022).
T1-T2 with DARTS approximation — classic approach using finite-difference hypergradient estimation (Luketina et al., 2016; Liu et al., 2018).
Greedy — generalized greedy method with inner-loop unrolling (Anonymous, ICLR 2025).
FO (First-Order) — first-order baseline that uses only the direct gradient.
One-Step — one-step lookahead baseline (Luketina et al., 2016).
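Informally, all five methods target the same bilevel objective: choose hyperparameters λ to minimize a validation (outer) loss at weights obtained by minimizing a training (inner) loss,

    min_λ  L_val(w*(λ), λ)    with    w*(λ) = argmin_w  L_train(w, λ),

and they differ mainly in how they estimate the hypergradient dL_val/dλ through w*(λ): inner-loop unrolling (Greedy), a one-step lookahead (One-Step), a finite-difference approximation (T1-T2/DARTS), or the direct gradient alone (FO).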
Key Features
Unified API — all algorithms inherit from ``BilevelOptimizer`` with methods ``init``, ``step``, ``compute_hypergradient``, and ``run``.
Arbitrary pytrees — model parameters and hyperparameters can be any nested structure of JAX arrays (see the sketch after this list).
JAX-compatible ``BilevelState`` — the state is registered as a JAX pytree, allowing it to be passed through ``jax.jit`` and ``jax.vmap``.
JIT compilation — all five ``step()`` methods are decorated with ``@partial(jax.jit, static_argnums=(0, 4, 5, 6))``.
Optax compatibility — ``GreedyOptimizer`` accepts ``optax.GradientTransformation`` objects for the inner and outer optimizers.
Custom step function — the other algorithms accept an arbitrary ``update_fn(w, lam, batch) -> w_new`` callable.
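Because parameters and hyperparameters are both plain pytrees, the hyperparameters can mirror the parameter structure, for example one learning rate per parameter leaf. A minimal sketch (the layer names, shapes, and softplus parameterization are illustrative assumptions, not part of GradHpO; such trees could serve as the ``w_init`` / ``lam_init`` used in the Quick Start below):

import jax
import jax.numpy as jnp

# Illustrative parameter pytree for a small two-layer model (shapes made up).
w_init = {
    "dense1": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros(8)},
    "dense2": {"kernel": jnp.zeros((8, 1)), "bias": jnp.zeros(1)},
}

# One unconstrained learning-rate hyperparameter per parameter leaf; an
# update_fn can map it through softplus to keep the step size positive.
lam_init = jax.tree.map(lambda _: jnp.asarray(0.0), w_init)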
Quick Start
import jax
import optax

from gradhpo import OnlineHypergradientOptimizer, BilevelState

# Define update_fn: one SGD step with per-parameter learning rates.
# train_loss, loss_fn, w_init, lam_init, train_iter, and val_iter are
# assumed to be defined by the user.
def update_fn(w, lr_params, batch):
    grads = jax.grad(train_loss)(w, batch)
    return jax.tree.map(
        lambda w_i, lr_i, g_i: w_i - jax.nn.softplus(lr_i) * g_i,
        w, lr_params, grads,
    )

opt = OnlineHypergradientOptimizer(
    update_fn=update_fn,
    gamma=0.99,
    estimation_period=10,
    T=20,
)

state = opt.init(w_init, lam_init)

# Main training loop
state = opt.run(
    state, M=60,
    get_train_batch=train_iter,
    get_val_batch=val_iter,
    train_loss_fn=loss_fn,
    val_loss_fn=loss_fn,
    lr_hyper=3e-3,
)
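Since ``BilevelState`` is registered as a JAX pytree (see Key Features), the returned state can be handled with standard pytree utilities. The snippet below only inspects its structure; the exact fields it contains are not assumed here:

import jax.numpy as jnp

# BilevelState is a registered pytree, so generic JAX utilities apply to it.
# This prints the state's tree structure and the shapes of its leaves
# without assuming any particular field names.
leaves, treedef = jax.tree_util.tree_flatten(state)
print(treedef)
print([jnp.shape(leaf) for leaf in leaves])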