Utils (utils)

Utility functions for JAX pytree operations and hypergradient computation.

Gradient utilities: pytree operations, VJP helpers, and EMA updates.

gradhpo.utils.gradients.tree_l2_norm(tree: Any) → Array

L2 norm of a pytree.
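The norm treats the whole pytree as one flattened vector. A minimal sketch of that computation, assuming only `jax.tree_util`; the name `tree_l2_norm_sketch` is hypothetical and this is not necessarily how the library implements it:

```python
import jax
import jax.numpy as jnp


def tree_l2_norm_sketch(tree):
    # Sum the squared entries of every leaf, then take one square root.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))


params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)}
norm = tree_l2_norm_sketch(params)  # sqrt(3^2 + 0^2 + 4^2)
```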

gradhpo.utils.gradients.tree_normalize(tree: Any, eps: float = 1e-08) → Any

Normalize pytree to unit L2 norm.
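One way to picture the normalization is to divide every leaf by the global L2 norm. The sketch below assumes `eps` is added to the norm in the denominator; the library may guard against division by zero differently. The name `tree_normalize_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_normalize_sketch(tree, eps=1e-8):
    leaves = jax.tree_util.tree_leaves(tree)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    # Assumption: eps is added to the norm to avoid dividing by zero.
    return jax.tree_util.tree_map(lambda leaf: leaf / (norm + eps), tree)


unit = tree_normalize_sketch({"w": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)})
```

Under this assumption, an all-zero pytree maps to an all-zero pytree rather than producing NaNs.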

gradhpo.utils.gradients.tree_dot(a: Any, b: Any) → Array

Inner product of two pytrees with the same structure.
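The inner product can be read as a leaf-wise elementwise product summed over all leaves. A minimal sketch for real-valued leaves, with the hypothetical name `tree_dot_sketch`:

```python
import jax
import jax.numpy as jnp


def tree_dot_sketch(a, b):
    # tree_map raises an error if the two structures differ.
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(products))


a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
b = {"w": jnp.array([4.0, 5.0]), "b": jnp.array(6.0)}
dot = tree_dot_sketch(a, b)  # 1*4 + 2*5 + 3*6
```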

gradhpo.utils.gradients.tree_zeros_like(tree: Any) → Any

Create a pytree of zeros with the same structure.
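A sketch of the same idea using `jax.tree_util.tree_map` with `jnp.zeros_like`, which preserves each leaf's shape and dtype. The name `tree_zeros_like_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_zeros_like_sketch(tree):
    # Replace every leaf with zeros of matching shape and dtype.
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


tree = {"w": jnp.ones((2, 3)), "b": jnp.ones(3)}
zeros = tree_zeros_like_sketch(tree)
```

A pytree like this is a natural initial value for an accumulator, for example a running hypergradient.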

gradhpo.utils.gradients.tree_lerp(a: Any, b: Any, t: float) → Any

Linear interpolation: (1-t)*a + t*b.
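The formula applies leaf by leaf. A minimal sketch with the hypothetical name `tree_lerp_sketch`:

```python
import jax
import jax.numpy as jnp


def tree_lerp_sketch(a, b, t):
    # (1 - t) * a + t * b on every pair of matching leaves.
    return jax.tree_util.tree_map(lambda x, y: (1.0 - t) * x + t * y, a, b)


a = {"w": jnp.array([0.0, 2.0])}
b = {"w": jnp.array([4.0, 6.0])}
mixed = tree_lerp_sketch(a, b, 0.25)  # 0.75 * a + 0.25 * b
```

An EMA update with decay `d` is the special case `tree_lerp_sketch(old_average, new_value, 1.0 - d)`.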

gradhpo.utils.gradients.vjp_wrt_lambda(update_fn: Callable, w: Any, lam: Any, batch: Any, alpha: Any) → Any

Compute alpha @ dPhi/dlambda via VJP.
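The product can be formed with `jax.vjp` without materializing the Jacobian. The sketch below assumes that `update_fn(w, lam, batch)` returns the next weights Phi with the same structure as `alpha`; that call signature, the name `vjp_wrt_lambda_sketch`, and the toy `sgd_step` are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def vjp_wrt_lambda_sketch(update_fn, w, lam, batch, alpha):
    # Assumption: update_fn(w, lam, batch) -> next weights Phi(w, lam).
    # Differentiate with respect to lam only; w and batch are held fixed.
    _, pullback = jax.vjp(lambda l: update_fn(w, l, batch), lam)
    (grad_lam,) = pullback(alpha)
    return grad_lam


# Hypothetical update rule: one SGD step on 0.5 * ||w - batch||^2,
# with the learning rate as the only hyperparameter.
def sgd_step(w, lam, batch):
    return w - lam["lr"] * (w - batch)


w = jnp.array([1.0, 2.0])
batch = jnp.array([0.0, 0.0])
lam = {"lr": jnp.array(0.1)}
alpha = jnp.array([1.0, 1.0])
grad_lam = vjp_wrt_lambda_sketch(sgd_step, w, lam, batch, alpha)
```

For this toy update, dPhi/dlr = -(w - batch), so the result is the dot product of `alpha` with `[-1, -2]`, and it has the same pytree structure as `lam`.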

gradhpo.utils.gradients.vjp_wrt_both(update_fn: Callable, w: Any, lam: Any, batch: Any, alpha: Any) → Tuple[Any, Any]

Compute alpha @ dPhi/dw and alpha @ dPhi/dlambda simultaneously.
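Computing both products from one `jax.vjp` call shares a single forward pass through the update. A sketch under the same assumption that `update_fn(w, lam, batch)` returns the next weights; `vjp_wrt_both_sketch` and `sgd_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def vjp_wrt_both_sketch(update_fn, w, lam, batch, alpha):
    # Assumption: update_fn(w, lam, batch) -> next weights Phi(w, lam).
    # One forward pass, one pullback, two cotangents.
    _, pullback = jax.vjp(lambda w_, l_: update_fn(w_, l_, batch), w, lam)
    grad_w, grad_lam = pullback(alpha)
    return grad_w, grad_lam


# Hypothetical update rule: one SGD step on 0.5 * ||w - batch||^2.
def sgd_step(w, lam, batch):
    return w - lam["lr"] * (w - batch)


w = jnp.array([1.0, 2.0])
batch = jnp.array([0.0, 0.0])
lam = {"lr": jnp.array(0.1)}
alpha = jnp.array([1.0, 1.0])
grad_w, grad_lam = vjp_wrt_both_sketch(sgd_step, w, lam, batch, alpha)
```

Here dPhi/dw = (1 - lr) * I, so `grad_w` is `0.9 * alpha`, while `grad_lam["lr"]` is the dot product of `alpha` with `-(w - batch)`.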

gradhpo.utils.gradients.update_w_star(w_star: Any, w_prev: Any, gamma: float, t: int) → Any

Eq. 13: sequential update of distilled weight point.

t = 1: w_1* = w_0

t >= 2: w_t* = p_t * w_{t-1}* + (1 - p_t) * w_{t-1}

where p_t = (gamma - gamma^t) / (1 - gamma^t).
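The recursion above can be sketched directly from the stated formulas. The sketch assumes that `t` is a Python integer, that `0 < gamma < 1` (so the denominator is nonzero), and that `w_prev` holds w_0 when `t` is 1; the name `update_w_star_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def update_w_star_sketch(w_star, w_prev, gamma, t):
    if t == 1:
        # w_1* = w_0; assumption: w_prev is w_0 on the first step.
        return w_prev
    p_t = (gamma - gamma**t) / (1.0 - gamma**t)
    # w_t* = p_t * w_{t-1}* + (1 - p_t) * w_{t-1}, applied leaf by leaf.
    return jax.tree_util.tree_map(
        lambda s, w: p_t * s + (1.0 - p_t) * w, w_star, w_prev
    )


w_star = {"w": jnp.array([3.0])}
w_prev = {"w": jnp.array([0.0])}
# gamma = 0.5, t = 2 gives p_2 = (0.5 - 0.25) / (1 - 0.25) = 1/3.
w_star_2 = update_w_star_sketch(w_star, w_prev, 0.5, 2)
```

Since p_1 = (gamma - gamma) / (1 - gamma) = 0, the t = 1 case agrees with the general formula: all weight falls on w_0.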