Overview

Problem Formulation

In the machine learning context, gradient-based hyperparameter optimization is formulated as a nested (bilevel) problem:

\[\min_{\lambda}\; L_{\mathrm{val}}\bigl(w^*(\lambda),\,\lambda\bigr), \qquad w^*(\lambda) = \arg\min_{w}\; L_{\mathrm{train}}(w,\,\lambda),\]

where \(w\) are the model parameters (inner level), \(\lambda\) are the hyperparameters (outer level), and \(L_{\mathrm{train}}\), \(L_{\mathrm{val}}\) are the training and validation loss functions respectively.

In practice, the solution \(w^*(\lambda)\) is not available in closed form, so it is approximated by a finite number of optimization steps. Denoting one step of the inner optimizer as \(\Phi(w, \lambda; D)\), after \(T\) steps we obtain \(w_T\), which depends on \(\lambda\) through the entire trajectory.
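
For concreteness, here is a minimal sketch of such an unrolled inner loop, assuming plain SGD as the inner step \(\Phi\) and a ridge-style training loss with a log-scale weight-decay hyperparameter; all names and losses below are illustrative, not part of the library API:

import jax
import jax.numpy as jnp

def train_loss(w, lam, batch):
    # Illustrative inner loss: least squares plus a weight-decay term whose
    # strength is the hyperparameter lam (on a log scale).
    x, y = batch
    return jnp.mean((x @ w - y) ** 2) + jnp.exp(lam) * jnp.sum(w ** 2)

def sgd_step(w, lam, batch, lr=0.1):
    # One inner step Phi(w, lam; D): plain gradient descent on L_train.
    return w - lr * jax.grad(train_loss)(w, lam, batch)

def unroll(w0, lam, batches):
    # w_T depends on lam through every one of the T steps.
    def body(w, batch):
        return sgd_step(w, lam, batch), None
    w_T, _ = jax.lax.scan(body, w0, batches)
    return w_T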

Hypergradient

The full hypergradient \(\mathrm{d}L_{\mathrm{val}} / \mathrm{d}\lambda\) is decomposed via the chain rule:

\[\frac{\mathrm{d}L_{\mathrm{val}}}{\mathrm{d}\lambda} = \underbrace{\frac{\partial L_{\mathrm{val}}}{\partial \lambda}}_{g_{\mathrm{FO}}} + \sum_{t=1}^{T} \alpha_T \cdot \Bigl(\prod_{s=t+1}^{T} A_s\Bigr) \cdot B_t,\]

where \(\alpha_t = \nabla_{w_t} L_{\mathrm{val}}(w_t)\) (so \(\alpha_T\) is the validation gradient at the final iterate), \(A_s = \partial \Phi / \partial w\) evaluated at step \(s\), and \(B_t = \partial \Phi / \partial \lambda\) evaluated at step \(t\).

Computing the full sum requires backpropagation through all \(T\) steps, which is expensive in memory and time. GradHpO implements several short-horizon approximations of this sum.
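
For reference, the exact hypergradient can be obtained by differentiating \(L_{\mathrm{val}}(w_T(\lambda))\) through the whole unrolled trajectory. The sketch below continues the illustrative snippet above (it reuses unroll and assumes a validation loss with no direct dependence on \(\lambda\)); this is precisely the expensive computation the short-horizon methods avoid:

def val_loss(w, lam, batch):
    # Illustrative outer loss: plain validation MSE. lam does not enter directly,
    # so g_FO is zero and the whole signal comes through the trajectory.
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

def full_hypergradient(w0, lam, train_batches, val_batch):
    # Reverse-mode differentiation of L_val(w_T(lam)) through all T inner steps:
    # exact, but the backward pass costs O(T) in time and memory.
    outer = lambda l: val_loss(unroll(w0, l, train_batches), l, val_batch)
    return jax.grad(outer)(lam)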

Implemented Algorithms

HyperDistill

HyperDistill, an online method with hypergradient distillation (Lee et al., ICLR 2022), approximates the full second-order term via an exponential-moving-average (EMA) point \(w^*_t\):

\[w^*_t = p_t \cdot w^*_{t-1} + (1 - p_t) \cdot w_{t-1}, \qquad p_t = \frac{\gamma - \gamma^t}{1 - \gamma^t}.\]

The hypergradient at step \(t\):

\[g_t = g_{\mathrm{FO}} + \theta \cdot \frac{1 - \gamma^t}{1 - \gamma} \cdot v_t,\]

where \(v_t = \alpha_t \cdot \partial\Phi(w^*_t, \lambda) / \partial\lambda\), and the scalar \(\theta\) is estimated periodically via a DrMAD-style backward pass (Algorithm 4 in the paper).

Class: OnlineHypergradientOptimizer.
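
As a sketch of the two formulas above, the EMA point and the scaled correction can be written as the following helpers (illustrative functions operating on arrays, not the internals of OnlineHypergradientOptimizer):

def ema_point(w_star_prev, w_prev, t, gamma):
    # w*_t = p_t * w*_{t-1} + (1 - p_t) * w_{t-1}, with p_t = (gamma - gamma^t) / (1 - gamma^t).
    # For t = 1 this gives p_1 = 0, so the EMA point starts at w_0.
    p_t = (gamma - gamma ** t) / (1.0 - gamma ** t)
    return p_t * w_star_prev + (1.0 - p_t) * w_prev

def hyperdistill_hypergradient(g_fo, v_t, theta, t, gamma):
    # g_t = g_FO + theta * (1 - gamma^t) / (1 - gamma) * v_t, with gamma in (0, 1).
    scale = theta * (1.0 - gamma ** t) / (1.0 - gamma)
    return g_fo + scale * v_t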

T1-T2 with DARTS

The T1-T2 algorithm (Luketina et al., 2016) alternates single gradient steps on the parameters and on the hyperparameters. In our implementation, \(B_t\) is computed using the DARTS finite-difference approximation (Liu et al., 2018):

\[B_t \approx \frac{\Phi(w, \lambda + \varepsilon e_i) - \Phi(w, \lambda - \varepsilon e_i)} {2\varepsilon}.\]

This avoids explicit differentiation through the optimizer.

Class: T1T2Optimizer.
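
A sketch of this central difference for a single hyperparameter coordinate \(i\), where phi stands for one inner step \(\Phi\); the helper name and step size are illustrative and do not reflect T1T2Optimizer internals:

import jax.numpy as jnp

def darts_column(phi, w, lam, i, eps=1e-2):
    # Central difference along the i-th hyperparameter coordinate e_i: approximates
    # the i-th column of B_t = dPhi/dlambda with two extra evaluations of Phi,
    # instead of differentiating through the inner optimizer.
    e_i = jnp.zeros_like(lam).at[i].set(1.0)
    return (phi(w, lam + eps * e_i) - phi(w, lam - eps * e_i)) / (2.0 * eps)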

Greedy

The generalized greedy approach (Anonymous, ICLR 2025) accounts for \(T\) steps with exponential decay:

\[\hat{d}_\lambda = \nabla_\lambda L_{\mathrm{val}}(w_T) + \sum_{t=1}^{T} \gamma^{T-t}\, \nabla_{w_t} L_{\mathrm{val}}(w_t) \cdot B_t.\]

The parameter \(\gamma \in (0, 1]\) controls the contribution of early steps. Unlike other algorithms, GreedyOptimizer accepts inner_optimizer and outer_optimizer as optax.GradientTransformation objects instead of a custom update_fn.

Class: GreedyOptimizer.
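
The decayed sum can be accumulated online with a single running buffer, so nothing has to be stored for all \(T\) steps. A sketch, assuming the per-step products \(\nabla_{w_t} L_{\mathrm{val}}(w_t) \cdot B_t\) are already available (illustrative only, not GreedyOptimizer's internals):

def greedy_hypergradient(direct_grad, step_terms, gamma):
    # step_terms contains grad_{w_t} L_val(w_t) . B_t for t = 1..T in chronological order.
    # The running update acc <- gamma * acc + term gives the term from step t
    # the weight gamma^(T - t), so only one accumulator is kept.
    acc = 0.0
    for term in step_terms:
        acc = gamma * acc + term
    return direct_grad + acc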

Baselines

  • FO (First-Order): uses only the direct gradient \(g_{\mathrm{FO}} = \partial L_{\mathrm{val}} / \partial \lambda\). If \(\lambda\) does not appear directly in \(L_{\mathrm{val}}\), the update is zero.

  • One-Step: accounts for \(B_t\) only at the last step, \(g = g_{\mathrm{FO}} + \alpha_T \cdot B_T\). Equivalent to HyperDistill with \(\gamma = 0\); see the sketch below.

Classes: FOOptimizer, OneStepOptimizer.
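
For the One-Step baseline, the remaining term \(\alpha_T \cdot B_T\) can be formed as a vector-Jacobian product without materializing \(B_T\). A sketch, assuming a single inner step phi and a val_loss(w, lam, batch) signature; the names are illustrative, not the library's internals:

import jax

def one_step_hypergradient(phi, w_prev, lam, val_loss, val_batch):
    # Final iterate w_T = Phi(w_{T-1}, lam) and alpha_T = grad_w L_val(w_T).
    w_T = phi(w_prev, lam)
    alpha_T = jax.grad(val_loss, argnums=0)(w_T, lam, val_batch)
    # Direct term g_FO = dL_val/dlambda evaluated at the final iterate.
    g_fo = jax.grad(val_loss, argnums=1)(w_T, lam, val_batch)
    # alpha_T . B_T as a vector-Jacobian product, without forming B_T explicitly.
    _, vjp_lam = jax.vjp(lambda l: phi(w_prev, l), lam)
    (correction,) = vjp_lam(alpha_T)
    return g_fo + correction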

Library Architecture

All algorithms inherit from BilevelOptimizer and implement three methods:

from abc import ABC

class BilevelOptimizer(ABC):
    # Build the initial BilevelState from model parameters and hyperparameters.
    def init(self, params, hyperparams) -> BilevelState: ...
    # Perform one bilevel update and return the new state.
    def step(self, state, train_batch, val_batch,
             train_loss_fn, val_loss_fn, lr_hyper) -> BilevelState: ...
    # The algorithm-specific approximation of dL_val/dlambda.
    def compute_hypergradient(self, state, train_batch, val_batch,
                              train_loss_fn, val_loss_fn) -> PyTree: ...
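
A hypothetical end-to-end call pattern against this interface; only the init() and step() signatures above come from the library, while the FOOptimizer constructor call, loss functions, and data iterables are placeholder assumptions:

# Hypothetical usage sketch: constructor arguments, losses, and data handling
# are assumptions; only the init()/step() signatures are taken from above.
optimizer = FOOptimizer()
state = optimizer.init(params, hyperparams)
for train_batch, val_batch in zip(train_batches, val_batches):
    state = optimizer.step(state, train_batch, val_batch,
                           train_loss_fn, val_loss_fn, lr_hyper=1e-2)
trained_params, tuned_hyperparams = state.params, state.hyperparams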

Note

The step() signature of GreedyOptimizer does not include lr_hyper (the outer optimizer step size is set via outer_optimizer at construction time). All other algorithms require lr_hyper as a mandatory argument.

The optimization state is stored in BilevelState — an immutable container with fields params, hyperparams, inner_opt_state, outer_opt_state, step, and metadata.

BilevelState is registered as a JAX pytree via jax.tree_util.register_pytree_node, allowing it to be passed directly to jax.jit, jax.vmap, and other JAX transformations.
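
For instance, a whole BilevelState can be passed through jax.jit unchanged, and tree utilities apply to it directly (a minimal illustration; state stands for any existing BilevelState instance):

import jax

@jax.jit
def touch(state):
    # A jitted function may accept and return the full optimization state.
    return state

# Tree utilities work on the state as well, e.g. inspecting parameter shapes.
param_shapes = jax.tree_util.tree_map(lambda p: p.shape, state.params)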

The step() methods are decorated with @partial(jax.jit, static_argnums=(0, 4, 5, 6)), where the static arguments are self (0), train_loss_fn (4), val_loss_fn (5), and lr_hyper (6); for GreedyOptimizer, whose step() has no lr_hyper argument, only arguments 0, 4, and 5 are static.