Overview
Problem Formulation
The bilevel optimization problem in the machine learning context is formulated as a nested optimization problem:

\[
\min_{\lambda} \; L_{\mathrm{val}}\bigl(w^{*}(\lambda), \lambda\bigr)
\quad \text{s.t.} \quad
w^{*}(\lambda) = \arg\min_{w} \; L_{\mathrm{train}}(w, \lambda),
\]
where \(w\) are the model parameters (inner level), \(\lambda\) are the hyperparameters (outer level), and \(L_{\mathrm{train}}\), \(L_{\mathrm{val}}\) are the training and validation loss functions respectively.
In practice, the solution \(w^*(\lambda)\) is not available in closed form, so it is approximated by a finite number of optimization steps. Denoting one step of the inner optimizer as \(\Phi(w, \lambda; D)\), after \(T\) steps we obtain \(w_T\), which depends on \(\lambda\) through the entire trajectory.
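For concreteness, here is a minimal sketch of this unrolling, assuming the inner optimizer \(\Phi\) is plain SGD and that train_loss_fn(w, lam, batch) returns a scalar (both are illustrative assumptions, not the library's actual update rule):

import jax

def inner_step(w, lam, batch, train_loss_fn, lr=1e-2):
    # One step of the inner optimizer Phi(w, lambda; D); here plain SGD.
    grads = jax.grad(train_loss_fn)(w, lam, batch)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, w, grads)

def unroll(w0, lam, batches, train_loss_fn, T=10):
    # w_T depends on lambda through every step of the trajectory.
    w = w0
    for t in range(T):
        w = inner_step(w, lam, batches[t % len(batches)], train_loss_fn)
    return w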
Hypergradient
The full hypergradient \(\mathrm{d}L_{\mathrm{val}} / \mathrm{d}\lambda\) is decomposed via the chain rule over the unrolled trajectory:

\[
\frac{\mathrm{d}L_{\mathrm{val}}}{\mathrm{d}\lambda}
= \frac{\partial L_{\mathrm{val}}}{\partial \lambda}
+ \sum_{t=1}^{T} \alpha_T \Bigl(\prod_{s=t+1}^{T} A_s\Bigr) B_t,
\]
where \(\alpha_t = \nabla_{w_t} L_{\mathrm{val}}(w_t)\), \(A_s = \partial \Phi / \partial w\), and \(B_t = \partial \Phi / \partial \lambda\).
Computing the full sum requires backpropagation through all \(T\) steps, which is expensive in memory and time. GradHpO implements several short-horizon approximations of this sum.
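As a reference point, the exact hypergradient can be obtained by differentiating through the entire unrolled trajectory. A sketch, reusing the illustrative unroll helper above, shows what the short-horizon methods approximate:

def full_hypergradient(w0, lam, train_batches, val_batch,
                       train_loss_fn, val_loss_fn, T=10):
    # Backpropagates through all T inner steps; storing the trajectory makes
    # this O(T) in memory, which is what GradHpO's approximations avoid.
    def objective(lam):
        w_T = unroll(w0, lam, train_batches, train_loss_fn, T=T)
        return val_loss_fn(w_T, lam, val_batch)
    return jax.grad(objective)(lam)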
Implemented Algorithms
HyperDistill
The online method with hypergradient distillation (Lee et al., ICLR 2022) approximates the full second-order term via an EMA point \(w^*_t\):
The hypergradient at step \(t\):
where \(v_t = \alpha_t \cdot \partial\Phi(w^*_t, \lambda) / \partial\lambda\), and the scalar \(\theta\) is estimated periodically via a DrMAD-style backward pass (Algorithm 4 in the paper).
Class: OnlineHypergradientOptimizer.
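The EMA point can be maintained cheaply alongside the inner iterates. One possible form, shown purely for illustration (the decay coefficient rho is a placeholder and is not the scalar \(\theta\) described above; the library's actual update may differ):

def update_ema(w_ema, w, rho=0.99):
    # Exponential moving average of the inner iterates, one tree_map per step.
    return jax.tree_util.tree_map(
        lambda e, p: rho * e + (1.0 - rho) * p, w_ema, w)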
T1-T2 with DARTS
The T1-T2 algorithm (Luketina et al., 2016) separates the parameter and hyperparameter update steps. In our implementation, \(B_t\) is computed using the DARTS finite-difference approximation (Liu et al., 2018), which replaces the mixed second derivative with two extra gradient evaluations:

\[
\nabla^{2}_{\lambda, w} L_{\mathrm{train}}(w_t, \lambda)\, \alpha_t \approx
\frac{\nabla_{\lambda} L_{\mathrm{train}}(w^{+}, \lambda) - \nabla_{\lambda} L_{\mathrm{train}}(w^{-}, \lambda)}{2\epsilon},
\qquad w^{\pm} = w_t \pm \epsilon\, \alpha_t.
\]
This avoids explicit differentiation through the optimizer.
Class: T1T2Optimizer.
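A minimal sketch of the finite-difference step, assuming train_loss_fn(w, lam, batch) returns a scalar and alpha is the validation gradient \(\alpha_t\) (names and signature are illustrative, not the library's internals):

def darts_second_order_term(w, lam, batch, train_loss_fn, alpha, eps=1e-2):
    # Central difference for the mixed second derivative applied to alpha:
    # two extra gradient evaluations instead of a Hessian-vector product.
    w_plus = jax.tree_util.tree_map(lambda p, a: p + eps * a, w, alpha)
    w_minus = jax.tree_util.tree_map(lambda p, a: p - eps * a, w, alpha)
    g_plus = jax.grad(train_loss_fn, argnums=1)(w_plus, lam, batch)
    g_minus = jax.grad(train_loss_fn, argnums=1)(w_minus, lam, batch)
    return jax.tree_util.tree_map(
        lambda gp, gm: (gp - gm) / (2.0 * eps), g_plus, g_minus)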
Greedy
The generalized greedy approach (Anonymous, ICLR 2025) accounts for \(T\) steps with exponential decay:

\[
g = g_{\mathrm{FO}} + \sum_{t=1}^{T} \gamma^{\,T-t}\, \alpha_t \cdot B_t.
\]
The parameter \(\gamma \in (0, 1]\) controls the contribution of early
steps. Unlike other algorithms, GreedyOptimizer accepts
inner_optimizer and outer_optimizer as optax.GradientTransformation
objects instead of a custom update_fn.
Class: GreedyOptimizer.
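Since the optimizers are passed as optax transformations, constructing the optimizer might look roughly like the following (the keyword names inner_optimizer and outer_optimizer follow this documentation; the gamma keyword is an assumption):

import optax

# GreedyOptimizer is imported from the library; module path omitted here.
greedy = GreedyOptimizer(
    inner_optimizer=optax.sgd(learning_rate=1e-2),
    outer_optimizer=optax.adam(learning_rate=1e-3),
    gamma=0.9,
)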
Baselines
FO (First-Order): uses only the direct gradient \(g_{\mathrm{FO}} = \partial L_{\mathrm{val}} / \partial \lambda\). If \(\lambda\) does not appear directly in \(L_{\mathrm{val}}\), the update is zero.
One-Step: accounts for \(B_t\) only at the last step, \(g = g_{\mathrm{FO}} + \alpha_T \cdot B_T\). Equivalent to HyperDistill with \(\gamma = 0\).
Classes: FOOptimizer,
OneStepOptimizer.
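The FO direction is just the partial derivative of the validation loss with respect to \(\lambda\) at fixed \(w\); a sketch assuming the signature val_loss_fn(w, lam, val_batch):

def fo_hypergradient(w, lam, val_batch, val_loss_fn):
    # Direct gradient only: zero whenever lambda does not enter val_loss_fn.
    return jax.grad(val_loss_fn, argnums=1)(w, lam, val_batch)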
Library Architecture
All algorithms inherit from BilevelOptimizer and implement three methods:
class BilevelOptimizer(ABC):
    def init(self, params, hyperparams) -> BilevelState: ...

    def step(self, state, train_batch, val_batch,
             train_loss_fn, val_loss_fn, lr_hyper) -> BilevelState: ...

    def compute_hypergradient(self, state, train_batch, val_batch,
                              train_loss_fn, val_loss_fn) -> PyTree: ...
Note
The step() signature of GreedyOptimizer does not include
lr_hyper (the outer optimizer step size is set via outer_optimizer
at construction time). All other algorithms require lr_hyper as a
mandatory argument.
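Based on this interface, a typical driver loop might look like the sketch below (the concrete optimizer class, its constructor arguments, and the batch iteration are placeholders):

# Hypothetical driver loop built only from the documented interface.
optimizer = OnlineHypergradientOptimizer(...)  # constructor arguments omitted
state = optimizer.init(params, hyperparams)

for train_batch, val_batch in zip(train_batches, val_batches):
    state = optimizer.step(state, train_batch, val_batch,
                           train_loss_fn, val_loss_fn, lr_hyper=1e-3)

print(state.hyperparams)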
The optimization state is stored in BilevelState — an immutable
container with fields params, hyperparams, inner_opt_state,
outer_opt_state, step, and metadata.
BilevelState is registered as a JAX pytree via
jax.tree_util.register_pytree_node, allowing it to be passed directly
to jax.jit, jax.vmap, and other JAX transformations.
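For reference, registering such a container with JAX looks roughly like the following generic sketch (this is not the library's actual flatten/unflatten code; treating metadata as static auxiliary data is an assumption):

def _flatten(state):
    # Array-bearing fields (and the step counter) become pytree children;
    # metadata travels as static auxiliary data.
    children = (state.params, state.hyperparams, state.inner_opt_state,
                state.outer_opt_state, state.step)
    return children, state.metadata

def _unflatten(metadata, children):
    params, hyperparams, inner_opt_state, outer_opt_state, step = children
    return BilevelState(params, hyperparams, inner_opt_state,
                        outer_opt_state, step, metadata)

jax.tree_util.register_pytree_node(BilevelState, _flatten, _unflatten)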
All five step() methods are decorated with
@partial(jax.jit, static_argnums=(0, 4, 5, 6)), where the static
arguments are self (0), train_loss_fn (4), val_loss_fn (5),
and lr_hyper (6).