Algorithms (algorithms)

OnlineHypergradientOptimizer

OnlineHypergradientOptimizer: HyperDistill algorithm (Lee et al., ICLR 2022).

Implements online hyperparameter meta-learning with hypergradient distillation.

class gradhpo.algorithms.online.OnlineHypergradientOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.99, estimation_period: int = 50, T: int = 20, update_fn: Callable | None = None)[source]

Bases: BilevelOptimizer

Online optimization with hypergradient distillation.

Implements Algorithm 3 and Algorithm 4 from Lee et al. (ICLR 2022).

Variables:
  • inner_optimizer – Optax optimizer for parameters (optional if update_fn given).

  • outer_optimizer – Optax optimizer for hyperparameters (optional, falls back to SGD).

  • gamma – EMA decay factor in [0, 1].

  • estimation_period – Re-estimate theta every estimation_period episodes.

  • T – Number of inner steps per episode.

  • update_fn – Optional custom inner step function Phi(w, lam, batch) -> w_new.

__init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.99, estimation_period: int = 50, T: int = 20, update_fn: Callable | None = None)[source]

Initialize Online Hypergradient optimizer.

Parameters:
  • inner_optimizer – Optax optimizer for parameters.

  • outer_optimizer – Optax optimizer for hyperparameters.

  • gamma – EMA decay factor.

  • estimation_period – Re-estimate theta every estimation_period episodes.

  • T – Inner steps per episode.

  • update_fn – Custom inner step Phi(w, lam, batch) -> w_new. If provided, inner_optimizer is not used.

init(params: Any, hyperparams: Any) BilevelState[source]

Initialize state with hypergradient memory.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

Returns:

Initial BilevelState.

compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]

Compute hypergradient w.r.t. hyperparameters.

Uses the HyperDistill approximation: g = g_FO + theta * pi_t * v_t.

Parameters:
  • state – Current state (must have w_star, theta in metadata).

  • train_batch – Training batch.

  • val_batch – Validation batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

Returns:

Hypergradient with same structure as hyperparams.
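
An illustrative scalar sketch (not the library implementation) of the combination g = g_FO + theta * pi_t * v_t, where v_t acts as an EMA memory that distils past second-order contributions. All numeric values below are assumed toy inputs:

```python
gamma = 0.99                    # EMA decay, as in the gamma argument
contribs = [0.10, -0.05, 0.20]  # assumed per-step second-order terms
v = 0.0
for c in contribs:
    v = gamma * v + c           # EMA memory distilling past terms

g_fo = 0.5                      # assumed direct gradient dL_val/dlam
theta, pi = 0.8, 1.0            # assumed scaling and weight
g = g_fo + theta * pi * v       # distilled hypergradient estimate
```

The EMA keeps the per-step cost constant: the correction is updated in place instead of storing the full unrolled computation graph.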

step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]

Perform one online optimization step with distillation.

Parameters:
  • state – Current bilevel state.

  • train_batch – Training data batch.

  • val_batch – Validation data batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

  • lr_hyper – Manual learning rate (used when outer_optimizer is None).

Returns:

Updated state.

estimate_theta(state: BilevelState, train_batches: List[Any], val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) float[source]

Algorithm 4: estimate the scaling parameter theta.

Runs a forward pass (T inner steps), then a DrMAD-style backward pass to collect samples (x_s, y_s), and fits theta = (x^T y) / (x^T x).

Parameters:
  • state – Current state (uses params and hyperparams).

  • train_batches – List of T training batches.

  • val_batch – Validation batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

Returns:

Estimated linear scaling parameter.

Return type:

float
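
The closed-form fit at the end of estimate_theta is an ordinary least-squares scaling: given paired samples (x_s, y_s), theta minimises sum_s (y_s - theta * x_s)^2. A self-contained sketch with assumed sample values:

```python
xs = [0.5, 1.0, 2.0]      # assumed samples x_s from the backward pass
ys = [0.45, 0.90, 1.80]   # assumed samples y_s; here exactly 0.9 * xs
# theta = (x^T y) / (x^T x), the least-squares scaling through the origin
theta = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
```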

run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]

Full HyperDistill training loop (Algorithm 3).

Runs M episodes, each with T inner steps and a Reptile update.

Parameters:
  • state – Initial state (from init()).

  • M – Number of outer episodes.

  • get_train_batch – Callable returning a training batch.

  • get_val_batch – Callable returning a validation batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

  • lr_reptile – Reptile learning rate for weight initialization.

  • lr_hyper – Manual hyper LR (used when outer_optimizer is None).

  • callback – Optional callback(episode, state).

Returns:

Final BilevelState.
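
A hedged sketch of the episode structure in run() (Algorithm 3): M episodes of T inner steps each, followed by a Reptile update of the weight initialization. The inner step and all constants below are assumed toy choices, not the library's code:

```python
def reptile_update(w_init, w_final, lr_reptile=0.5):
    # interpolate the initialization toward the adapted weights
    return w_init + lr_reptile * (w_final - w_init)

w_init = 0.0
for episode in range(3):           # M = 3 episodes
    w = w_init
    for t in range(5):             # T = 5 inner steps
        w = w - 0.1 * (w - 1.0)    # toy inner step pulling w toward 1.0
    # (the hyperparameter update from the hypergradient goes here)
    w_init = reptile_update(w_init, w)
```

With lr_reptile = 1.0 (the default in run()) the initialization simply becomes the final adapted weights of the episode.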

T1T2Optimizer

T1-T2 optimizer with numerical DARTS approximation.

Implements the T1-T2 algorithm from Luketina et al. (2016), using the numerical DARTS approximation described in Liu et al. (2018).

class gradhpo.algorithms.t1t2.T1T2Optimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.9, T: int = 20, update_fn: Callable | None = None, eps: float = 0.0001)[source]

Bases: BilevelOptimizer

T1-T2 optimizer with numerical DARTS approximation.

Implements the T1-T2 algorithm with a fixed horizon length, using the numerical DARTS approximation to compute the hypergradient.

Variables:
  • inner_optimizer – Optax optimizer for parameters (optional if update_fn given).

  • outer_optimizer – Optax optimizer for hyperparameters (optional).

  • gamma – EMA decay factor in [0, 1] for weight distillation.

  • T – Number of inner steps per episode.

  • update_fn – Optional custom inner step function Phi(w, lam, batch) -> w_new.

  • eps – Small epsilon for numerical differentiation in DARTS approximation.

__init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.9, T: int = 20, update_fn: Callable | None = None, eps: float = 0.0001)[source]

Initialize T1-T2 optimizer with numerical DARTS approximation.

Parameters:
  • inner_optimizer – Optax optimizer for parameters.

  • outer_optimizer – Optax optimizer for hyperparameters.

  • gamma – EMA decay factor for weight distillation.

  • T – Inner steps per episode.

  • update_fn – Custom inner step Phi(w, lam, batch) -> w_new. If provided, inner_optimizer is not used.

  • eps – Epsilon for numerical differentiation in DARTS approximation.

init(params: Any, hyperparams: Any) BilevelState[source]

Initialize state with hypergradient memory.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

Returns:

Initial BilevelState.

compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]

Compute hypergradient w.r.t. hyperparameters using numerical DARTS approximation.

Uses the T1-T2 approximation with numerical DARTS: g = g_FO + sum_{t=1}^T (alpha_t @ dPhi(w_{t-1}, lam; D_t) / dlam)

Parameters:
  • state – Current state (must have w_star in metadata).

  • train_batch – Training batch.

  • val_batch – Validation batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

Returns:

Hypergradient with same structure as hyperparams.
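
The numerical DARTS trick replaces the exact mixed second derivative with a central finite difference over perturbed weights w ± eps * v. A self-contained sketch on a toy loss (assumed example, not the library's code):

```python
def train_loss_grad_lam(w, lam):
    # toy loss L(w, lam) = lam * w**2, so dL/dlam = w**2
    return w ** 2

def mixed_term(w, lam, v, eps=1e-4):
    # (dL/dlam at w + eps*v  -  dL/dlam at w - eps*v) / (2*eps)
    # approximates v applied to the mixed derivative d2L / (dw dlam)
    return (train_loss_grad_lam(w + eps * v, lam)
            - train_loss_grad_lam(w - eps * v, lam)) / (2 * eps)

# for L = lam * w**2 the exact mixed derivative is 2*w, so the
# result should be close to 2 * 1.5 * 0.7 = 2.1
print(mixed_term(w=1.5, lam=0.3, v=0.7))
```

This costs two extra gradient evaluations per step instead of a Hessian-vector product, which is the motivation for the eps constructor argument.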

step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]

Perform one T1-T2 optimization step with numerical DARTS approximation.

Parameters:
  • state – Current bilevel state.

  • train_batch – Training data batch.

  • val_batch – Validation data batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

  • lr_hyper – Manual learning rate (used when outer_optimizer is None).

Returns:

Updated state.

run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]

Full T1-T2 training loop.

Runs M episodes, each with T inner steps and a Reptile update.

Parameters:
  • state – Initial state (from init()).

  • M – Number of outer episodes.

  • get_train_batch – Callable returning a training batch.

  • get_val_batch – Callable returning a validation batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

  • lr_reptile – Reptile learning rate for weight initialization.

  • lr_hyper – Manual hyper LR (used when outer_optimizer is None).

  • callback – Optional callback(episode, state).

Returns:

Final BilevelState.

GreedyOptimizer

class gradhpo.algorithms.greedy.GreedyOptimizer(inner_optimizer: GradientTransformation, outer_optimizer: GradientTransformation, unroll_steps: int = 1, gamma: float = 0.9)[source]

Bases: BilevelOptimizer

Generalized greedy gradient-based hyperparameter optimization.

This implements the approximation from Eq. (6) of the accompanying paper:

d_hat_alpha = ∇_alpha L_val(w_T, alpha) + sum_{t=1}^T gamma^(T-t) * ∇_{w_t} L_val(w_t, alpha) * B_t

where B_t = ∂Φ(w_{t-1}, alpha) / ∂alpha and Φ is one inner optimizer step.

Notes

  • unroll_steps corresponds to the horizon T.

  • gamma controls how strongly earlier greedy terms contribute.

  • A single train_batch / val_batch is reused across the unrolled inner steps, which matches the current library template API.
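
A toy scalar sketch of the Eq. (6) weighting (all numbers are assumed values): the local greedy term from inner step t is damped by gamma**(T - t), so recent steps dominate, and gamma -> 0 recovers a one-step lookahead:

```python
gamma, T = 0.9, 4
direct = 0.4                          # assumed grad_alpha L_val(w_T, alpha)
local_terms = [0.2, -0.1, 0.3, 0.05]  # assumed per-step greedy terms, t = 1..T
d_hat = direct + sum(gamma ** (T - t) * g
                     for t, g in enumerate(local_terms, start=1))
```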

__init__(inner_optimizer: GradientTransformation, outer_optimizer: GradientTransformation, unroll_steps: int = 1, gamma: float = 0.9)[source]

Initialize bilevel optimizer.

init(params: Any, hyperparams: Any) BilevelState[source]

Initialize Greedy optimization state.

step(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) BilevelState[source]

Perform one greedy optimization step.

Algorithm

  1. Unroll unroll_steps inner updates on train loss.

  2. Compute greedy hypergradient as a weighted sum of local greedy terms.

  3. Update hyperparameters using the outer optimizer.

  4. Keep the unrolled parameters / inner optimizer state.

compute_hypergradient(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]

Compute hypergradient by unrolling inner optimization.

run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]

Full Greedy training loop.

Runs M outer episodes. Each episode performs unroll_steps inner gradient steps (the greedy unroll horizon) and then updates hyperparameters with the greedy hypergradient. A Reptile update is applied to the weight initialization after every episode.

Parameters:
  • state – Initial state (from init()).

  • M – Number of outer episodes.

  • get_train_batch – Callable returning a training batch.

  • get_val_batch – Callable returning a validation batch.

  • train_loss_fn – Training loss function (params, hyperparams, batch).

  • val_loss_fn – Validation loss function (params, hyperparams, batch).

  • lr_reptile – Reptile learning rate for weight initialization.

  • lr_hyper – Manual hyper LR used when outer_optimizer is None.

  • callback – Optional callback(episode, state).

Returns:

Final BilevelState.

Baselines

Baseline bilevel optimizers: first-order and one-step lookahead.

class gradhpo.algorithms.baselines.FOOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]

Bases: BilevelOptimizer

First-order baseline: ignores the second-order term entirely.

Uses only the direct gradient g_FO = dL_val/dlam. If the validation loss does not depend directly on the hyperparameters, lambda receives no update.

Variables:
  • inner_optimizer – Optax optimizer for parameters (optional if update_fn given).

  • outer_optimizer – Optax optimizer for hyperparameters (optional).

  • update_fn – Optional custom inner step Phi(w, lam, batch) -> w_new.

__init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]

Initialize bilevel optimizer.

init(params: Any, hyperparams: Any) BilevelState[source]

Initialize the bilevel optimization state.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

Returns:

Initial optimization state.

compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]

First-order hypergradient: g_FO = dL_val/dlam.

step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]

run(state: BilevelState, M: int, T: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]

class gradhpo.algorithms.baselines.OneStepOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]

Bases: BilevelOptimizer

One-step lookahead baseline (Luketina et al., 2016).

Computes hypergradient using only the last step (short horizon, gamma=0).

Variables:
  • inner_optimizer – Optax optimizer for parameters (optional if update_fn given).

  • outer_optimizer – Optax optimizer for hyperparameters (optional).

  • update_fn – Optional custom inner step Phi(w, lam, batch) -> w_new.

__init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]

Initialize bilevel optimizer.

init(params: Any, hyperparams: Any) BilevelState[source]

Initialize the bilevel optimization state.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

Returns:

Initial optimization state.

compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]

One-step hypergradient: g_FO + alpha @ B_t.
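
A toy contrast between the two baselines (assumed example functions, not library code): when L_val touches lam only through the weights, the first-order term g_FO is identically zero and only the one-step correction moves lam:

```python
def inner_step(w, lam, lr=0.1):
    # Phi: one SGD step on L_train(w, lam) = 0.5 * (w - lam)**2
    return w - lr * (w - lam)

def val_loss(w):
    # L_val has no direct lam term, so dL_val/dlam at fixed w is 0
    return (w - 1.0) ** 2

w0, lam, lr = 0.0, 0.5, 0.1
w1 = inner_step(w0, lam, lr)
g_fo = 0.0                        # FO baseline: L_val never sees lam directly
# one-step term: dL_val/dw1 * dPhi/dlam; here dL_val/dw1 = 2*(w1 - 1)
# and dPhi/dlam = lr, so the correction is nonzero
g_one_step = 2.0 * (w1 - 1.0) * lr
```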

step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]

run(state: BilevelState, M: int, T: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]