Source code for gradhpo.algorithms.t1t2

"""T1-T2 optimizer with numerical DARTS approximation.

Implements the T1-T2 algorithm from Luketina et al. (2016) with numerical
DARTS approximation as described in Liu et al. (2018).
"""
from functools import partial
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import optax

from gradhpo.core.base import BilevelOptimizer
from gradhpo.core.state import BilevelState
from gradhpo.core.types import PyTree, LossFn
from gradhpo.utils.gradients import tree_dot, update_w_star



class T1T2Optimizer(BilevelOptimizer):
    """T1-T2 optimizer with numerical DARTS approximation.

    Implements the T1-T2 algorithm with a fixed horizon length and uses
    numerical DARTS approximation for computing the hypergradient.

    Attributes:
        inner_optimizer: Optax optimizer for parameters (optional if update_fn given).
        outer_optimizer: Optax optimizer for hyperparameters (optional).
        gamma: EMA decay factor in [0, 1] for weight distillation.
        T: Number of inner steps per episode.
        update_fn: Optional custom inner step function Phi(w, lam, batch) -> w_new.
        eps: Small epsilon for numerical differentiation in DARTS approximation.
    """
    def __init__(
        self,
        inner_optimizer: Optional[optax.GradientTransformation] = None,
        outer_optimizer: Optional[optax.GradientTransformation] = None,
        gamma: float = 0.9,
        T: int = 20,
        update_fn: Optional[Callable] = None,
        eps: float = 1e-4,
    ):
        """Initialize T1-T2 optimizer with numerical DARTS approximation.

        Args:
            inner_optimizer: Optax optimizer for parameters.
            outer_optimizer: Optax optimizer for hyperparameters.
            gamma: EMA decay factor for weight distillation.
            T: Inner steps per episode.
            update_fn: Custom inner step Phi(w, lam, batch) -> w_new.
                If provided, inner_optimizer is not used.
            eps: Epsilon for numerical differentiation in DARTS approximation.
        """
        super().__init__(inner_optimizer, outer_optimizer)
        self.gamma = gamma
        self.T = T
        self._update_fn = update_fn
        self.eps = eps
    def _get_inner_step_fn(
        self,
        state: BilevelState,
        train_loss_fn: LossFn,
    ) -> Callable:
        """Return the inner step function Phi(w, lam, batch) -> w_new.

        If a custom update_fn was provided, use it directly. Otherwise,
        build one from train_loss_fn + inner_optimizer.
        """
        if self._update_fn is not None:
            return self._update_fn

        inner_opt = self.inner_optimizer
        inner_opt_state = state.inner_opt_state

        def phi(params, hyperparams, batch):
            grads = jax.grad(train_loss_fn, argnums=0)(params, hyperparams, batch)
            updates, _ = inner_opt.update(grads, inner_opt_state, params)
            return optax.apply_updates(params, updates)

        return phi
    def init(
        self,
        params: PyTree,
        hyperparams: PyTree,
    ) -> BilevelState:
        """Initialize state with hypergradient memory.

        Args:
            params: Initial model parameters.
            hyperparams: Initial hyperparameters.

        Returns:
            Initial BilevelState.
        """
        inner_opt_state = (
            self.inner_optimizer.init(params)
            if self.inner_optimizer is not None
            else None
        )
        outer_opt_state = (
            self.outer_optimizer.init(hyperparams)
            if self.outer_optimizer is not None
            else None
        )
        return BilevelState(
            params=params,
            hyperparams=hyperparams,
            inner_opt_state=inner_opt_state,
            outer_opt_state=outer_opt_state,
            step=0,
            metadata={
                "w_star": params,
            },
        )
    def compute_hypergradient(
        self,
        state: BilevelState,
        train_batch: Any,
        val_batch: Any,
        train_loss_fn: LossFn,
        val_loss_fn: LossFn,
    ) -> PyTree:
        """Compute the hypergradient w.r.t. hyperparameters via numerical DARTS.

        Uses the T1-T2 approximation with numerical DARTS:

            g = g_FO + sum_{t=1}^T (alpha_t @ dPhi(w_{t-1}, lam; D_t) / dlam)

        This method computes the contribution for the current step t,
        i.e. g_t = g_FO + v_t.

        Args:
            state: Current state (must have w_star in metadata).
            train_batch: Training batch.
            val_batch: Validation batch.
            train_loss_fn: Training loss function.
            val_loss_fn: Validation loss function.

        Returns:
            Hypergradient with same structure as hyperparams.
        """
        t = state.step
        w_star = state.get_metric("w_star")
        phi_fn = self._get_inner_step_fn(state, train_loss_fn)

        # w_t = Phi(w_{t-1}, lam; D_t)
        w_new = phi_fn(state.params, state.hyperparams, train_batch)

        # alpha_t = dL_val(w_t, lam) / dw_t
        alpha = jax.grad(val_loss_fn, argnums=0)(
            w_new, state.hyperparams, val_batch)

        # g_FO = dL_val / dlam (direct gradient)
        g_fo = jax.grad(val_loss_fn, argnums=1)(
            w_new, state.hyperparams, val_batch)

        # Compute v_t using the numerical DARTS approximation:
        # v_t = alpha @ dPhi(w_star, lam; D_t) / dlam
        w_star_new = update_w_star(w_star, state.params, self.gamma, t)
        v_t = self._numerical_darts_approximation(
            phi_fn, w_star_new, state.hyperparams, train_batch, alpha)

        # g = g_FO + v_t
        return jax.tree.map(lambda fo, v: fo + v, g_fo, v_t)
    def _numerical_darts_approximation(
        self,
        phi_fn: Callable,
        w: PyTree,
        lam: PyTree,
        batch: Any,
        alpha: PyTree,
    ) -> PyTree:
        """Compute alpha @ dPhi/dlambda using numerical DARTS approximation.

        Uses central finite differences to approximate the derivative:

            dPhi/dlambda ≈ [Phi(w, lam + eps*e_i, batch)
                            - Phi(w, lam - eps*e_i, batch)] / (2*eps)

        Args:
            phi_fn: Inner step function Phi(w, lam, batch) -> w_new.
            w: Model parameters.
            lam: Hyperparameters.
            batch: Training batch.
            alpha: Gradient vector for VJP computation.

        Returns:
            alpha @ dPhi/dlambda computed via numerical approximation.
        """
        flat_lam, tree_def = jax.tree_util.tree_flatten(lam)
        result_leaves = []

        for i, lam_element in enumerate(flat_lam):
            shape_i = lam_element.shape

            def compute_element_gradient(pert_val):
                # Perturb only leaf i, uniformly across its elements.
                perturbed_lam_leaves = [
                    leaf + (
                        pert_val * jnp.ones(leaf.shape)
                        if j == i
                        else jnp.zeros(leaf.shape)
                    )
                    for j, leaf in enumerate(flat_lam)
                ]
                perturbed_lam = jax.tree_util.tree_unflatten(
                    tree_def, perturbed_lam_leaves)
                w_new = phi_fn(w, perturbed_lam, batch)
                return tree_dot(w_new, alpha)

            # Central difference: (f(x+eps) - f(x-eps)) / (2*eps).
            f_plus = compute_element_gradient(self.eps)
            f_minus = compute_element_gradient(-self.eps)
            gradient_element = (f_plus - f_minus) / (2 * self.eps)

            # Because each leaf is perturbed uniformly, this is the directional
            # derivative along the all-ones direction for leaf i, broadcast
            # back to the leaf's shape; it is exact for scalar leaves.
            result_leaves.append(gradient_element * jnp.ones(shape_i))

        return jax.tree_util.tree_unflatten(tree_def, result_leaves)
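
    # Worked example (illustrative): for scalar w, lam and an inner step
    # Phi(w, lam) = w - 0.1 * lam * w, the central difference above gives
    #
    #     f(+eps) = alpha * (w - 0.1 * (lam + eps) * w)
    #     f(-eps) = alpha * (w - 0.1 * (lam - eps) * w)
    #     (f(+eps) - f(-eps)) / (2 * eps) = -0.1 * alpha * w,
    #
    # which matches the exact derivative alpha * dPhi/dlam (Phi is linear
    # in lam here, so the finite difference incurs no truncation error).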

    @partial(jax.jit, static_argnums=(0, 4, 5, 6))
    def step(
        self,
        state: BilevelState,
        train_batch: Any,
        val_batch: Any,
        train_loss_fn: LossFn,
        val_loss_fn: LossFn,
        lr_hyper: Optional[float] = None,
    ) -> BilevelState:
        """Perform one T1-T2 optimization step with numerical DARTS approximation.

        Args:
            state: Current bilevel state.
            train_batch: Training data batch.
            val_batch: Validation data batch.
            train_loss_fn: Training loss function.
            val_loss_fn: Validation loss function.
            lr_hyper: Manual learning rate (used when outer_optimizer is None).

        Returns:
            Updated state.
        """
        t = state.step + 1
        w_star = state.get_metric("w_star")
        phi_fn = self._get_inner_step_fn(state, train_loss_fn)

        # 1. Update w_star (Eq. 13)
        w_star_new = update_w_star(w_star, state.params, self.gamma, t)

        # 2. Inner step: w_t = Phi(w_{t-1}, lam; D_t)
        w_new = phi_fn(state.params, state.hyperparams, train_batch)

        # 3. alpha_t and g_FO
        alpha = jax.grad(val_loss_fn, argnums=0)(
            w_new, state.hyperparams, val_batch)
        g_fo = jax.grad(val_loss_fn, argnums=1)(
            w_new, state.hyperparams, val_batch)

        # 4. v_t using numerical DARTS approximation
        v_t = self._numerical_darts_approximation(
            phi_fn, w_star_new, state.hyperparams, train_batch, alpha)

        # 5. Hypergradient
        hyper_grad = jax.tree.map(lambda fo, v: fo + v, g_fo, v_t)

        # 6. Update hyperparams
        if self.outer_optimizer is not None:
            updates, new_outer_opt_state = self.outer_optimizer.update(
                hyper_grad, state.outer_opt_state, state.hyperparams)
            lam_new = optax.apply_updates(state.hyperparams, updates)
        else:
            assert lr_hyper is not None, "lr_hyper required when outer_optimizer is None"
            lam_new = jax.tree.map(
                lambda l, g: l - lr_hyper * g, state.hyperparams, hyper_grad)
            new_outer_opt_state = state.outer_opt_state

        # 7. Update inner opt state (if using optax inner optimizer)
        new_inner_opt_state = state.inner_opt_state
        if self.inner_optimizer is not None and self._update_fn is None:
            grads = jax.grad(train_loss_fn, argnums=0)(
                state.params, state.hyperparams, train_batch)
            _, new_inner_opt_state = self.inner_optimizer.update(
                grads, state.inner_opt_state, state.params)

        return state.update(
            params=w_new,
            hyperparams=lam_new,
            inner_opt_state=new_inner_opt_state,
            outer_opt_state=new_outer_opt_state,
            step=t,
            metadata={
                "w_star": w_star_new,
            },
        )
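
    # Example (illustrative): calling step() directly with a custom inner
    # update instead of an optax inner optimizer. The names below (sgd_phi,
    # state, train_batch, val_batch, the loss functions) are hypothetical:
    #
    #     sgd_phi = lambda w, lam, b: jax.tree.map(
    #         lambda p, g: p - 1e-2 * g,
    #         w, jax.grad(train_loss_fn)(w, lam, b))
    #     opt = T1T2Optimizer(update_fn=sgd_phi, outer_optimizer=optax.adam(1e-2))
    #     state = opt.init(params, hyperparams)
    #     state = opt.step(state, train_batch, val_batch,
    #                      train_loss_fn, val_loss_fn)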

    def run(
        self,
        state: BilevelState,
        M: int,
        get_train_batch: Callable,
        get_val_batch: Callable,
        train_loss_fn: LossFn,
        val_loss_fn: LossFn,
        lr_reptile: float = 1.0,
        lr_hyper: Optional[float] = None,
        callback: Optional[Callable] = None,
    ) -> BilevelState:
        """Full T1-T2 training loop.

        Runs M episodes, each with T inner steps and a Reptile update.

        Args:
            state: Initial state (from init()).
            M: Number of outer episodes.
            get_train_batch: Callable returning a training batch.
            get_val_batch: Callable returning a validation batch.
            train_loss_fn: Training loss function.
            val_loss_fn: Validation loss function.
            lr_reptile: Reptile learning rate for the weight-initialization update.
            lr_hyper: Manual hyper LR (used when outer_optimizer is None).
            callback: Optional callback(episode, state).

        Returns:
            Final BilevelState.
        """
        phi = state.params

        for m in range(1, M + 1):
            # Reset step counter and params for the new episode
            state = state.update(
                params=phi,
                step=0,
                metadata={"w_star": phi},
            )
            # Anneal the manual hyper LR linearly over episodes. Note that
            # lr_hyper is a static argument of the jitted step(), so each
            # distinct value triggers a retrace.
            lr_current = (
                lr_hyper * (1.0 - (m - 1) / M)
                if lr_hyper is not None
                else None
            )

            # Inner optimization (T steps)
            for t in range(1, self.T + 1):
                train_batch = get_train_batch()
                val_batch = get_val_batch()
                state = self.step(
                    state, train_batch, val_batch,
                    train_loss_fn, val_loss_fn, lr_current)

            # Reptile update: phi <- phi - lr_reptile * (phi - w_T)
            phi = jax.tree.map(
                lambda p, wt: p - lr_reptile * (p - wt), phi, state.params)

            if callback is not None:
                val_b = get_val_batch()
                val_loss = float(
                    val_loss_fn(state.params, state.hyperparams, val_b))
                state = state.update(
                    metadata={"val_loss": val_loss},
                )
                callback(m, state)

        state = state.update(params=phi)
        return state
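

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): tunes a scalar log weight
    # decay on a toy linear-regression problem. The data, loss functions,
    # batch callables, and learning rates below are hypothetical; only the
    # T1T2Optimizer API mirrors the class defined above.
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (64, 4))
    y = x @ jnp.array([1.0, -2.0, 0.5, 0.0])

    def train_loss(params, hyperparams, batch):
        xb, yb = batch
        mse = jnp.mean((xb @ params["w"] - yb) ** 2)
        # The hyperparameter is the (log-scale) weight decay strength.
        return mse + jnp.exp(hyperparams["log_wd"]) * jnp.sum(params["w"] ** 2)

    def val_loss(params, hyperparams, batch):
        xb, yb = batch
        return jnp.mean((xb @ params["w"] - yb) ** 2)

    opt = T1T2Optimizer(
        inner_optimizer=optax.sgd(1e-2),
        outer_optimizer=optax.adam(1e-2),
        gamma=0.9,
        T=10,
    )
    state = opt.init(
        params={"w": jnp.zeros(4)},
        hyperparams={"log_wd": jnp.array(-3.0)},
    )
    state = opt.run(
        state,
        M=5,
        get_train_batch=lambda: (x, y),
        get_val_batch=lambda: (x, y),
        train_loss_fn=train_loss,
        val_loss_fn=val_loss,
    )
    print("tuned log weight decay:", state.hyperparams["log_wd"])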