"""T1-T2 optimizer with numerical DARTS approximation.
Implements the T1-T2 algorithm from Luketina et al. (2016) with numerical
DARTS approximation as described in Liu et al. (2018).
"""
from functools import partial
from typing import Any, Callable, Optional
import jax
import jax.numpy as jnp
import optax
from gradhpo.core.base import BilevelOptimizer
from gradhpo.core.state import BilevelState
from gradhpo.core.types import PyTree, LossFn
from gradhpo.utils.gradients import tree_dot, update_w_star
class T1T2Optimizer(BilevelOptimizer):
    """T1-T2 optimizer with numerical DARTS approximation.

    Implements the T1-T2 algorithm (Luketina et al., 2016) with a fixed
    horizon length, using a numerical (finite-difference) DARTS-style
    approximation (Liu et al., 2018) to estimate the hypergradient.

    Attributes:
        inner_optimizer: Optax optimizer for parameters (optional if
            update_fn given).
        outer_optimizer: Optax optimizer for hyperparameters (optional).
        gamma: EMA decay factor in [0, 1] for weight distillation.
        T: Number of inner steps per episode.
        update_fn: Optional custom inner step function
            Phi(w, lam, batch) -> w_new.
        eps: Small epsilon for numerical differentiation in the DARTS
            approximation.
    """

    def __init__(
        self,
        inner_optimizer: Optional[optax.GradientTransformation] = None,
        outer_optimizer: Optional[optax.GradientTransformation] = None,
        gamma: float = 0.9,
        T: int = 20,
        update_fn: Optional[Callable] = None,
        eps: float = 1e-4,
    ):
        """Initialize T1-T2 optimizer with numerical DARTS approximation.

        Args:
            inner_optimizer: Optax optimizer for parameters.
            outer_optimizer: Optax optimizer for hyperparameters.
            gamma: EMA decay factor for weight distillation, in [0, 1].
            T: Inner steps per episode (must be >= 1).
            update_fn: Custom inner step Phi(w, lam, batch) -> w_new.
                If provided, inner_optimizer is not used.
            eps: Epsilon for numerical differentiation (must be > 0).

        Raises:
            ValueError: If gamma is outside [0, 1], T < 1, or eps <= 0.
        """
        # Validate eagerly: out-of-range values would otherwise surface only
        # as silent numerical garbage (or a cryptic error) deep inside a
        # jitted trace.
        if not 0.0 <= gamma <= 1.0:
            raise ValueError(f"gamma must be in [0, 1], got {gamma}")
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        if eps <= 0.0:
            raise ValueError(f"eps must be positive, got {eps}")
        super().__init__(inner_optimizer, outer_optimizer)
        self.gamma = gamma
        self.T = T
        self._update_fn = update_fn
        self.eps = eps
def _get_inner_step_fn(
    self, state: BilevelState, train_loss_fn: LossFn,
) -> Callable:
    """Return the inner step function Phi(w, lam, batch) -> w_new.

    If a custom ``update_fn`` was provided at construction time it is used
    directly. Otherwise a single gradient step is built from
    ``train_loss_fn`` and ``inner_optimizer``.

    Args:
        state: Current bilevel state; its ``inner_opt_state`` is captured
            (frozen) by the returned function.
        train_loss_fn: Training loss ``L(params, hyperparams, batch)``.

    Returns:
        Callable ``phi(params, hyperparams, batch) -> new_params``.

    Raises:
        ValueError: If neither ``update_fn`` nor ``inner_optimizer`` was
            provided.
    """
    if self._update_fn is not None:
        return self._update_fn
    if self.inner_optimizer is None:
        # Fail fast with a clear message; previously this fell through to
        # an opaque AttributeError (`None.update`) deep inside a trace.
        raise ValueError(
            "T1T2Optimizer requires either `update_fn` or `inner_optimizer`."
        )
    inner_opt = self.inner_optimizer
    # NOTE: the optimizer state is frozen at the time phi is built. phi is
    # deliberately re-evaluated at perturbed hyperparameters with this same
    # state, which the DARTS-style finite difference requires.
    inner_opt_state = state.inner_opt_state

    def phi(params, hyperparams, batch):
        grads = jax.grad(train_loss_fn, argnums=0)(params, hyperparams, batch)
        updates, _ = inner_opt.update(grads, inner_opt_state, params)
        return optax.apply_updates(params, updates)

    return phi
def init(
    self,
    params: PyTree,
    hyperparams: PyTree,
) -> BilevelState:
    """Create the initial bilevel state.

    Args:
        params: Initial model parameters.
        hyperparams: Initial hyperparameters.

    Returns:
        A ``BilevelState`` at step 0 whose metadata holds ``w_star`` (the
        distillation target), initialized to ``params``.
    """
    if self.inner_optimizer is None:
        inner_state = None
    else:
        inner_state = self.inner_optimizer.init(params)

    if self.outer_optimizer is None:
        outer_state = None
    else:
        outer_state = self.outer_optimizer.init(hyperparams)

    return BilevelState(
        params=params,
        hyperparams=hyperparams,
        inner_opt_state=inner_state,
        outer_opt_state=outer_state,
        step=0,
        metadata={"w_star": params},
    )
def compute_hypergradient(
    self,
    state: BilevelState,
    train_batch: Any,
    val_batch: Any,
    train_loss_fn: LossFn,
    val_loss_fn: LossFn,
) -> PyTree:
    """Hypergradient of the validation loss w.r.t. the hyperparameters.

    Uses the T1-T2 approximation with numerical DARTS:
        g = g_FO + alpha_t @ dPhi(w_star, lam; D_t) / dlam
    where g_FO is the direct validation gradient w.r.t. lam and the
    indirect term is estimated by finite differences.

    Args:
        state: Current state (must have "w_star" in metadata).
        train_batch: Training batch.
        val_batch: Validation batch.
        train_loss_fn: Training loss function.
        val_loss_fn: Validation loss function.

    Returns:
        Hypergradient with the same pytree structure as hyperparams.
    """
    step_count = state.step
    distilled = state.get_metric("w_star")
    inner_step = self._get_inner_step_fn(state, train_loss_fn)

    # Take one inner step: w_t = Phi(w_{t-1}, lam; D_t).
    stepped_params = inner_step(state.params, state.hyperparams, train_batch)

    # alpha_t = dL_val/dw evaluated at the stepped parameters.
    alpha = jax.grad(val_loss_fn, argnums=0)(
        stepped_params, state.hyperparams, val_batch)
    # Direct term: g_FO = dL_val/dlam.
    direct = jax.grad(val_loss_fn, argnums=1)(
        stepped_params, state.hyperparams, val_batch)

    # Indirect term, evaluated at the freshly distilled weights rather
    # than at the raw current parameters.
    distilled = update_w_star(distilled, state.params, self.gamma, step_count)
    indirect = self._numerical_darts_approximation(
        inner_step, distilled, state.hyperparams, train_batch, alpha)

    # g = g_FO + v_t
    return jax.tree.map(lambda d, v: d + v, direct, indirect)
def _numerical_darts_approximation(
    self,
    phi_fn: Callable,
    w: PyTree,
    lam: PyTree,
    batch: Any,
    alpha: PyTree,
) -> PyTree:
    """Compute alpha @ dPhi/dlambda via central finite differences.

    Approximates, per hyperparameter coordinate i,
        d(alpha @ Phi)/dlam_i
            ≈ [f(lam + eps*e_i) - f(lam - eps*e_i)] / (2*eps),
    where f(lam) = alpha @ Phi(w, lam, batch) and e_i is the i-th standard
    basis vector.

    Fix: the previous implementation perturbed *every* component of a leaf
    at once (eps * ones) and broadcast the resulting scalar directional
    derivative back over the leaf — correct only for scalar / one-element
    leaves, wrong for vector-valued hyperparameters. Here every coordinate
    gets its own basis perturbation (batched with jax.vmap), matching the
    e_i formula; results for scalar leaves are unchanged.

    Args:
        phi_fn: Inner step function Phi(w, lam, batch) -> w_new.
        w: Model parameters at which Phi is evaluated.
        lam: Hyperparameters (arbitrary pytree).
        batch: Training batch.
        alpha: Cotangent pytree (same structure as Phi's output) for the
            VJP-style contraction.

    Returns:
        Pytree with the structure and shapes of ``lam`` holding the
        numerical gradient estimate. Cost: two evaluations of Phi per
        hyperparameter coordinate, vectorized within each leaf via vmap.
    """
    flat_lam, tree_def = jax.tree_util.tree_flatten(lam)

    def objective(leaves):
        # f(lam) = alpha @ Phi(w, lam, batch), lam given as flat leaves.
        lam_tree = jax.tree_util.tree_unflatten(tree_def, leaves)
        return tree_dot(phi_fn(w, lam_tree, batch), alpha)

    result_leaves = []
    for i, leaf in enumerate(flat_lam):
        # One standard-basis direction per coordinate of this leaf:
        # basis[k] has leaf.shape with a 1 at flat position k.
        basis = jnp.eye(leaf.size, dtype=leaf.dtype).reshape(
            (leaf.size,) + leaf.shape)

        def central_difference(direction, i=i):
            plus = [
                l + self.eps * direction if j == i else l
                for j, l in enumerate(flat_lam)
            ]
            minus = [
                l - self.eps * direction if j == i else l
                for j, l in enumerate(flat_lam)
            ]
            return (objective(plus) - objective(minus)) / (2.0 * self.eps)

        per_coordinate = jax.vmap(central_difference)(basis)
        result_leaves.append(per_coordinate.reshape(leaf.shape))
    return jax.tree_util.tree_unflatten(tree_def, result_leaves)
@partial(jax.jit, static_argnums=(0, 4, 5, 6))
def step(
    self,
    state: BilevelState,
    train_batch: Any,
    val_batch: Any,
    train_loss_fn: LossFn,
    val_loss_fn: LossFn,
    lr_hyper: Optional[float] = None,
) -> BilevelState:
    """Perform one T1-T2 optimization step with numerical DARTS approximation.

    One step: advance the distillation target w_star, take one inner step
    on the parameters, form the hypergradient (direct + numerical-DARTS
    indirect term), update the hyperparameters, and advance the inner
    optimizer state.

    NOTE(review): ``lr_hyper`` is in ``static_argnums``, so each distinct
    value triggers a re-trace; ``run()`` passes a new value every episode,
    giving up to M traces — confirm this retracing cost is acceptable.

    Args:
        state: Current bilevel state.
        train_batch: Training data batch.
        val_batch: Validation data batch.
        train_loss_fn: Training loss function (static for jit).
        val_loss_fn: Validation loss function (static for jit).
        lr_hyper: Manual learning rate (used when outer_optimizer is None).

    Returns:
        Updated state (new params, hyperparams, optimizer states, step
        count, and "w_star" metadata).
    """
    t = state.step + 1
    w_star = state.get_metric("w_star")
    # phi captures the *pre-step* inner optimizer state (see step 7 below).
    phi_fn = self._get_inner_step_fn(state, train_loss_fn)
    # 1. Update w_star (Eq. 13)
    w_star_new = update_w_star(w_star, state.params, self.gamma, t)
    # 2. Inner step: w_t = Phi(w_{t-1}, lam; D_t)
    w_new = phi_fn(state.params, state.hyperparams, train_batch)
    # 3. alpha_t = dL_val/dw at w_t, and the direct term g_FO = dL_val/dlam.
    alpha = jax.grad(val_loss_fn, argnums=0)(
        w_new, state.hyperparams, val_batch)
    g_fo = jax.grad(val_loss_fn, argnums=1)(
        w_new, state.hyperparams, val_batch)
    # 4. Indirect term v_t via numerical DARTS approximation, evaluated at
    #    the distilled weights w_star rather than the current parameters.
    v_t = self._numerical_darts_approximation(
        phi_fn, w_star_new, state.hyperparams, train_batch, alpha)
    # 5. Hypergradient: g = g_FO + v_t
    hyper_grad = jax.tree.map(lambda fo, v: fo + v, g_fo, v_t)
    # 6. Update hyperparams — via optax when available, otherwise plain SGD
    #    with the manually supplied lr_hyper.
    if self.outer_optimizer is not None:
        updates, new_outer_opt_state = self.outer_optimizer.update(
            hyper_grad, state.outer_opt_state, state.hyperparams)
        lam_new = optax.apply_updates(state.hyperparams, updates)
    else:
        # lr_hyper is static, so this assert fires at trace time.
        assert lr_hyper is not None, "lr_hyper required when outer_optimizer is None"
        lam_new = jax.tree.map(
            lambda l, g: l - lr_hyper * g, state.hyperparams, hyper_grad)
        new_outer_opt_state = state.outer_opt_state
    # 7. Update inner opt state (if using optax inner optimizer).
    #    NOTE(review): this recomputes the training gradients that phi_fn
    #    already took in step 2 — confirm the duplication is intentional
    #    (phi_fn's captured optimizer state must stay frozen for step 4).
    new_inner_opt_state = state.inner_opt_state
    if self.inner_optimizer is not None and self._update_fn is None:
        grads = jax.grad(train_loss_fn, argnums=0)(
            state.params, state.hyperparams, train_batch)
        _, new_inner_opt_state = self.inner_optimizer.update(
            grads, state.inner_opt_state, state.params)
    return state.update(
        params=w_new,
        hyperparams=lam_new,
        inner_opt_state=new_inner_opt_state,
        outer_opt_state=new_outer_opt_state,
        step=t,
        metadata={
            "w_star": w_star_new,
        },
    )
def run(
    self,
    state: BilevelState,
    M: int,
    get_train_batch: Callable,
    get_val_batch: Callable,
    train_loss_fn: LossFn,
    val_loss_fn: LossFn,
    lr_reptile: float = 1.0,
    lr_hyper: Optional[float] = None,
    callback: Optional[Callable] = None,
) -> BilevelState:
    """Full T1-T2 training loop.

    Runs M episodes; each episode restarts from the meta-parameters, takes
    T inner steps, then applies a Reptile update to the meta-parameters.

    Args:
        state: Initial state (from init()).
        M: Number of outer episodes.
        get_train_batch: Callable returning a training batch.
        get_val_batch: Callable returning a validation batch.
        train_loss_fn: Training loss function.
        val_loss_fn: Validation loss function.
        lr_reptile: Reptile learning rate for weight initialization.
        lr_hyper: Manual hyper LR (used when outer_optimizer is None).
        callback: Optional callback(episode, state).

    Returns:
        Final BilevelState.
    """
    meta_params = state.params
    for episode in range(1, M + 1):
        # Restart the episode from the current meta-parameters.
        state = state.update(
            params=meta_params,
            step=0,
            metadata={"w_star": meta_params},
        )
        if lr_hyper is None:
            episode_lr = None
        else:
            # Linearly decay the manual hyper LR across episodes.
            episode_lr = lr_hyper * (1.0 - (episode - 1) / M)

        # T inner steps on fresh batches.
        for _ in range(1, self.T + 1):
            train_batch = get_train_batch()
            val_batch = get_val_batch()
            state = self.step(
                state, train_batch, val_batch,
                train_loss_fn, val_loss_fn, episode_lr)

        # Reptile update: interpolate meta-params toward the episode's
        # final weights: phi <- phi - lr_reptile * (phi - w_T).
        meta_params = jax.tree.map(
            lambda meta, final: meta - lr_reptile * (meta - final),
            meta_params, state.params)

        if callback is not None:
            probe_batch = get_val_batch()
            probe_loss = float(
                val_loss_fn(state.params, state.hyperparams, probe_batch))
            state = state.update(
                metadata={"val_loss": probe_loss},
            )
            callback(episode, state)

    return state.update(params=meta_params)