Algorithms (algorithms)
OnlineHypergradientOptimizer
OnlineHypergradientOptimizer: HyperDistill algorithm (Lee et al., ICLR 2022).
Implements online hyperparameter meta-learning with hypergradient distillation.
- class gradhpo.algorithms.online.OnlineHypergradientOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.99, estimation_period: int = 50, T: int = 20, update_fn: Callable | None = None)[source]
Bases: BilevelOptimizer
Online optimization with hypergradient distillation.
Implements Algorithm 3 and Algorithm 4 from Lee et al. (ICLR 2022).
- Variables:
inner_optimizer – Optax optimizer for parameters (optional if update_fn given).
outer_optimizer – Optax optimizer for hyperparameters (optional, falls back to SGD).
gamma – EMA decay factor in [0, 1].
estimation_period – Re-estimate theta every N episodes.
T – Number of inner steps per episode.
update_fn – Optional custom inner step function Phi(w, lam, batch) -> w_new.
- __init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.99, estimation_period: int = 50, T: int = 20, update_fn: Callable | None = None)[source]
Initialize Online Hypergradient optimizer.
- Parameters:
inner_optimizer – Optax optimizer for parameters.
outer_optimizer – Optax optimizer for hyperparameters.
gamma – EMA decay factor.
estimation_period – Re-estimate theta every N episodes.
T – Inner steps per episode.
update_fn – Custom inner step Phi(w, lam, batch) -> w_new. If provided, inner_optimizer is not used.
- init(params: Any, hyperparams: Any) BilevelState[source]
Initialize state with hypergradient memory.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
- Returns:
Initial BilevelState.
- compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]
Compute hypergradient w.r.t. hyperparameters.
Uses the HyperDistill approximation: g = g_FO + theta * pi_t * v_t.
- Parameters:
state – Current state (must have w_star, theta in metadata).
train_batch – Training batch.
val_batch – Validation batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
- Returns:
Hypergradient with same structure as hyperparams.
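The additive structure of the HyperDistill estimate, g = g_FO + theta * pi_t * v_t, can be sketched with plain NumPy. This is purely structural and illustrative: the placeholder values for pi_t (accumulated weight) and v_t (distilled second-order direction) are made up, and theta stands in for the scalar produced by estimate_theta().

```python
import numpy as np

def hyperdistill_hypergradient(g_fo, theta, pi_t, v_t):
    """Combine the first-order term with the scaled distilled correction:
    g = g_FO + theta * pi_t * v_t."""
    return g_fo + theta * pi_t * v_t

g_fo = np.array([0.5, -1.0])   # direct gradient dL_val/dlam
v_t = np.array([2.0, 4.0])     # placeholder distilled direction
g = hyperdistill_hypergradient(g_fo, theta=0.25, pi_t=0.5, v_t=v_t)
# g == g_fo + 0.125 * v_t == [0.75, -0.5]
```

The returned array has the same structure as the direct gradient, matching the documented return value.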
- step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]
Perform one online optimization step with distillation.
- Parameters:
state – Current bilevel state.
train_batch – Training data batch.
val_batch – Validation data batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
lr_hyper – Manual learning rate (used when outer_optimizer is None).
- Returns:
Updated state.
- estimate_theta(state: BilevelState, train_batches: List[Any], val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) float[source]
Algorithm 4: estimate the scaling parameter theta.
Runs a forward pass (T inner steps), then a DrMAD-style backward pass to collect samples (x_s, y_s), and fits theta = (x^T y) / (x^T x).
- Parameters:
state – Current state (uses params and hyperparams).
train_batches – List of T training batches.
val_batch – Validation batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
- Returns:
Estimated linear scaling parameter.
- Return type:
float
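The closed-form fit in Algorithm 4 is an ordinary least-squares solution for a single scalar: theta = (x^T y) / (x^T x) minimizes sum_s (y_s - theta * x_s)^2 over the collected samples. A minimal NumPy sketch (the sample values below are invented for illustration):

```python
import numpy as np

def fit_theta(xs, ys):
    """Scalar least-squares fit: argmin_theta sum_s (y_s - theta * x_s)^2,
    solved in closed form as (x^T y) / (x^T x)."""
    x = np.asarray(xs, dtype=float)
    y = np.asarray(ys, dtype=float)
    return float(x @ y / (x @ x))

# Samples (x_s, y_s) as collected by the DrMAD-style backward pass.
xs = [1.0, 2.0, 3.0]
ys = [2.1, 3.9, 6.0]  # roughly y = 2x, so theta should be close to 2
theta = fit_theta(xs, ys)
```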
- run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]
Full HyperDistill training loop (Algorithm 3).
Runs M episodes, each with T inner steps and a Reptile update.
- Parameters:
state – Initial state (from init()).
M – Number of outer episodes.
get_train_batch – Callable returning a training batch.
get_val_batch – Callable returning a validation batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
lr_reptile – Reptile learning rate for weight initialization.
lr_hyper – Manual hyper LR (used when outer_optimizer is None).
callback – Optional callback(episode, state).
- Returns:
Final BilevelState.
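The Reptile update applied after each episode moves the weight initialization toward the episode's final weights. A sketch, assuming the standard Reptile rule w0 <- w0 + lr_reptile * (w_T - w0):

```python
import numpy as np

def reptile_update(w0, w_T, lr_reptile=1.0):
    """Move the initialization toward the adapted weights.
    lr_reptile = 1.0 simply adopts w_T as the next initialization."""
    return w0 + lr_reptile * (w_T - w0)

w0 = np.zeros(3)
w_T = np.array([1.0, 2.0, 3.0])  # weights after T inner steps
w0_half = reptile_update(w0, w_T, lr_reptile=0.5)  # halfway: [0.5, 1.0, 1.5]
w0_full = reptile_update(w0, w_T, lr_reptile=1.0)  # adopts w_T outright
```

With the default lr_reptile = 1.0 each episode simply continues from where the previous one ended; smaller values keep the initialization closer to its current point.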
T1T2Optimizer
T1-T2 optimizer with numerical DARTS approximation.
Implements the T1-T2 algorithm from Luketina et al. (2016) with numerical DARTS approximation as described in Liu et al. (2018).
- class gradhpo.algorithms.t1t2.T1T2Optimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.9, T: int = 20, update_fn: Callable | None = None, eps: float = 0.0001)[source]
Bases: BilevelOptimizer
T1-T2 optimizer with numerical DARTS approximation.
Implements the T1-T2 algorithm with a fixed horizon length and uses numerical DARTS approximation for computing the hypergradient.
- Variables:
inner_optimizer – Optax optimizer for parameters (optional if update_fn given).
outer_optimizer – Optax optimizer for hyperparameters (optional).
gamma – EMA decay factor in [0, 1] for weight distillation.
T – Number of inner steps per episode.
update_fn – Optional custom inner step function Phi(w, lam, batch) -> w_new.
eps – Small epsilon for numerical differentiation in DARTS approximation.
- __init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, gamma: float = 0.9, T: int = 20, update_fn: Callable | None = None, eps: float = 0.0001)[source]
Initialize T1-T2 optimizer with numerical DARTS approximation.
- Parameters:
inner_optimizer – Optax optimizer for parameters.
outer_optimizer – Optax optimizer for hyperparameters.
gamma – EMA decay factor for weight distillation.
T – Inner steps per episode.
update_fn – Custom inner step Phi(w, lam, batch) -> w_new. If provided, inner_optimizer is not used.
eps – Epsilon for numerical differentiation in DARTS approximation.
- init(params: Any, hyperparams: Any) BilevelState[source]
Initialize state with hypergradient memory.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
- Returns:
Initial BilevelState.
- compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]
Compute hypergradient w.r.t. hyperparameters using numerical DARTS approximation.
Uses the T1-T2 approximation with numerical DARTS: g = g_FO + sum_{t=1}^T (alpha_t @ dPhi(w_{t-1}, lam; D_t) / dlam)
- Parameters:
state – Current state (must have w_star in metadata).
train_batch – Training batch.
val_batch – Validation batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
- Returns:
Hypergradient with same structure as hyperparams.
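The numerical DARTS trick replaces each mixed second-order term alpha_t @ dPhi/dlam with a central finite difference: perturb the weights by +/- eps along alpha and difference the gradient of the training loss w.r.t. lam. A sketch on a toy quadratic loss (the analytic value is computed only to check the approximation, it is not part of the algorithm):

```python
import numpy as np

def grad_lam(w, lam):
    # Toy training loss L_train = lam * ||w||^2, so dL_train/dlam = ||w||^2
    return np.sum(w ** 2)

def darts_mixed_term(w, lam, alpha, eps=1e-4):
    """Central-difference estimate of alpha^T d^2 L_train / (dw dlam):
    perturb the weights by +/- eps * alpha and difference dL/dlam."""
    return (grad_lam(w + eps * alpha, lam)
            - grad_lam(w - eps * alpha, lam)) / (2 * eps)

w = np.array([1.0, -2.0])
alpha = np.array([0.5, 1.0])  # e.g. a validation gradient w.r.t. the weights
approx = darts_mixed_term(w, lam=0.3, alpha=alpha)
exact = 2.0 * w @ alpha       # analytic value for this quadratic: 2 * w.alpha
```

For a quadratic loss the central difference is exact up to floating-point rounding, which is why a small eps such as the class default 1e-4 suffices.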
- step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]
Perform one T1-T2 optimization step with numerical DARTS approximation.
- Parameters:
state – Current bilevel state.
train_batch – Training data batch.
val_batch – Validation data batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
lr_hyper – Manual learning rate (used when outer_optimizer is None).
- Returns:
Updated state.
- run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]
Full T1-T2 training loop.
Runs M episodes, each with T inner steps and a Reptile update.
- Parameters:
state – Initial state (from init()).
M – Number of outer episodes.
get_train_batch – Callable returning a training batch.
get_val_batch – Callable returning a validation batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
lr_reptile – Reptile learning rate for weight initialization.
lr_hyper – Manual hyper LR (used when outer_optimizer is None).
callback – Optional callback(episode, state).
- Returns:
Final BilevelState.
GreedyOptimizer
- class gradhpo.algorithms.greedy.GreedyOptimizer(inner_optimizer: GradientTransformation, outer_optimizer: GradientTransformation, unroll_steps: int = 1, gamma: float = 0.9)[source]
Bases: BilevelOptimizer
Generalized greedy gradient-based hyperparameter optimization.
This implements the approximation from Eq. (6) of the accompanying paper:
d_hat_alpha = ∇_alpha L_val(w_T, alpha) + sum_{t=1}^{T} gamma^(T-t) * ∇_{w_t} L_val(w_t, alpha) * B_t
where B_t = ∂Φ(w_{t-1}, alpha) / ∂alpha and Φ is one inner optimizer step.
Notes
- unroll_steps corresponds to the horizon T.
- gamma controls how strongly earlier greedy terms contribute.
- A single train_batch / val_batch is reused across the unrolled inner steps, which matches the current library template API.
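The gamma^(T-t) weighting in Eq. (6) discounts earlier steps geometrically while the final step enters with full weight. A sketch where the per-step terms ∇_{w_t} L_val · B_t are stood in by placeholder vectors:

```python
import numpy as np

def greedy_hypergradient(g_direct, local_terms, gamma=0.9):
    """Weighted sum of local greedy terms: step t (1-indexed) is discounted
    by gamma^(T - t), so the final step has weight gamma^0 = 1."""
    T = len(local_terms)
    out = np.asarray(g_direct, dtype=float).copy()
    for t, g_t in enumerate(local_terms, start=1):
        out += gamma ** (T - t) * np.asarray(g_t, dtype=float)
    return out

g_direct = np.array([1.0])       # the direct term grad_alpha L_val(w_T, alpha)
locals_ = [np.array([1.0])] * 3  # placeholder per-step terms, T = 3
g = greedy_hypergradient(g_direct, locals_, gamma=0.5)
# weights: 0.25, 0.5, 1.0 -> g = 1 + 1.75 = [2.75]
```

Setting gamma = 0 recovers a one-step scheme (only the last term survives); gamma = 1 weights all steps equally.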
- __init__(inner_optimizer: GradientTransformation, outer_optimizer: GradientTransformation, unroll_steps: int = 1, gamma: float = 0.9)[source]
Initialize bilevel optimizer.
- init(params: Any, hyperparams: Any) BilevelState[source]
Initialize Greedy optimization state.
- step(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) BilevelState[source]
Perform one greedy optimization step.
Algorithm
1. Unroll unroll_steps inner updates on the training loss.
2. Compute the greedy hypergradient as a weighted sum of local greedy terms.
3. Update the hyperparameters using the outer optimizer.
4. Keep the unrolled parameters and inner optimizer state.
- compute_hypergradient(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]
Compute hypergradient by unrolling inner optimization.
- run(state: BilevelState, M: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]
Full Greedy training loop.
Runs M outer episodes. Each episode performs unroll_steps inner gradient steps (the greedy unroll horizon) and then updates the hyperparameters with the greedy hypergradient. A Reptile update is applied to the weight initialization after every episode.
- Parameters:
state – Initial state (from init()).
M – Number of outer episodes.
get_train_batch – Callable returning a training batch.
get_val_batch – Callable returning a validation batch.
train_loss_fn – Training loss function (params, hyperparams, batch).
val_loss_fn – Validation loss function (params, hyperparams, batch).
lr_reptile – Reptile learning rate for weight initialization.
lr_hyper – Manual hyper LR used when outer_optimizer is None.
callback – Optional callback(episode, state).
- Returns:
Final BilevelState.
Baselines
Baseline bilevel optimizers: first-order and one-step lookahead.
- class gradhpo.algorithms.baselines.FOOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]
Bases: BilevelOptimizer
First-order baseline: ignores the second-order term entirely.
Uses only the direct gradient g_FO = dL_val/dlam. If the validation loss does not depend on the hyperparameters, the hypergradient is zero and no update is made to lambda.
- Variables:
inner_optimizer – Optax optimizer for parameters (optional if update_fn given).
outer_optimizer – Optax optimizer for hyperparameters (optional).
update_fn – Optional custom inner step Phi(w, lam, batch) -> w_new.
- __init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]
Initialize bilevel optimizer.
- init(params: Any, hyperparams: Any) BilevelState[source]
Initialize the bilevel optimization state.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
- Returns:
Initial optimization state.
- compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]
First-order hypergradient: g_FO = dL_val/dlam.
- step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]
- run(state: BilevelState, M: int, T: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]
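The first-order baseline keeps only the direct gradient of the validation loss with respect to the hyperparameters, so a validation loss that never touches lam yields an exactly zero hypergradient. A sketch with toy losses, using central finite differences in place of autodiff (scalar lam for simplicity):

```python
import numpy as np

def fo_hypergradient(val_loss_fn, w, lam, eps=1e-6):
    """g_FO = dL_val/dlam via central finite differences (scalar lam)."""
    return (val_loss_fn(w, lam + eps) - val_loss_fn(w, lam - eps)) / (2 * eps)

w = np.array([1.0, 2.0])

# Validation loss that penalizes lam directly -> nonzero g_FO.
reg_val = lambda w, lam: np.sum(w ** 2) + lam ** 2
# Validation loss independent of lam -> g_FO == 0, so lambda never moves.
plain_val = lambda w, lam: np.sum(w ** 2)

g1 = fo_hypergradient(reg_val, w, lam=0.5)    # d(lam^2)/dlam at 0.5 -> ~1.0
g2 = fo_hypergradient(plain_val, w, lam=0.5)  # exactly 0.0
```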
- class gradhpo.algorithms.baselines.OneStepOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]
Bases: BilevelOptimizer
One-step lookahead baseline (Luketina et al., 2016).
Computes the hypergradient using only the last inner step (short horizon, gamma = 0).
- Variables:
inner_optimizer – Optax optimizer for parameters (optional if update_fn given).
outer_optimizer – Optax optimizer for hyperparameters (optional).
update_fn – Optional custom inner step Phi(w, lam, batch) -> w_new.
- __init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None, update_fn: Callable | None = None)[source]
Initialize bilevel optimizer.
- init(params: Any, hyperparams: Any) BilevelState[source]
Initialize the bilevel optimization state.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
- Returns:
Initial optimization state.
- compute_hypergradient(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) Any[source]
One-step hypergradient: g_FO + alpha @ B_t.
- step(state: BilevelState, train_batch: Any, val_batch: Any, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_hyper: float | None = None) BilevelState[source]
- run(state: BilevelState, M: int, T: int, get_train_batch: Callable, get_val_batch: Callable, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array], lr_reptile: float = 1.0, lr_hyper: float | None = None, callback: Callable | None = None) BilevelState[source]
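The one-step lookahead differentiates the validation loss through a single inner update w' = w - alpha * ∇_w L_train(w, lam). A sketch on a toy problem where everything is analytic (alpha here is a scalar inner step size; the direct term g_FO is zero by construction because this val_loss does not depend on lam, so only the lookahead term survives):

```python
import numpy as np

def train_grad_w(w, lam):
    # Toy training loss L_train = 0.5 * (w - lam)^2, so dL_train/dw = w - lam
    return w - lam

def val_loss(w):
    # Validation loss with no direct dependence on lam: L_val = 0.5 * w^2
    return 0.5 * w ** 2

def one_step_hypergradient(w, lam, alpha, eps=1e-6):
    """Differentiate L_val through one inner step w' = w - alpha * dL_train/dw,
    using central finite differences over lam."""
    def lookahead(lam_):
        return val_loss(w - alpha * train_grad_w(w, lam_))
    return (lookahead(lam + eps) - lookahead(lam - eps)) / (2 * eps)

w, lam, alpha = 2.0, 0.0, 0.1
g = one_step_hypergradient(w, lam, alpha)
w_next = w - alpha * train_grad_w(w, lam)  # 1.8
exact = alpha * w_next                     # chain rule: dL_val(w')/dlam = w' * alpha
```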