Utils (utils)
Utility functions for JAX pytree operations and hypergradient computation.
Gradient utilities: pytree operations, VJP helpers, and EMA updates.
- gradhpo.utils.gradients.tree_normalize(tree: Any, eps: float = 1e-08) → Any
Normalize a pytree to unit L2 norm.
- gradhpo.utils.gradients.tree_dot(a: Any, b: Any) → Array
Inner product of two pytrees with the same structure.
- gradhpo.utils.gradients.tree_zeros_like(tree: Any) → Any
Create a pytree of zeros with the same structure and leaf shapes as tree.
- gradhpo.utils.gradients.tree_lerp(a: Any, b: Any, t: float) → Any
Linear interpolation: (1-t)*a + t*b.
- gradhpo.utils.gradients.vjp_wrt_lambda(update_fn: Callable, w: Any, lam: Any, batch: Any, alpha: Any) → Any
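Compute alpha @ dPhi/dlambda, the product of alpha with the Jacobian of the update map Phi (given by update_fn) with respect to lambda, via a vector-Jacobian product (VJP), without forming the Jacobian explicitly.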
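The following is a minimal sketch of how tree_dot and tree_normalize could be written with jax.tree_util; it is an illustrative reimplementation, not the gradhpo source. It assumes the norm is taken over all leaves jointly and that eps is added to the norm before dividing, so that an all-zero tree does not cause a division by zero.

```python
import jax
import jax.numpy as jnp


# Hypothetical reimplementation for illustration; not the gradhpo source.
def tree_dot(a, b):
    """Sum of elementwise products over all matching leaves of a and b."""
    leaf_dots = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(leaf_dots))


def tree_normalize(tree, eps=1e-8):
    """Divide every leaf by the global L2 norm (assumption: eps is added to the norm)."""
    norm = jnp.sqrt(tree_dot(tree, tree))
    return jax.tree_util.tree_map(lambda x: x / (norm + eps), tree)


params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
print(float(tree_dot(params, params)))  # 25.0
unit = tree_normalize(params)
print(float(tree_dot(unit, unit)))  # close to 1.0
```

Treating the pytree as one flattened vector is what makes "unit L2 norm" well defined when leaves have different shapes.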
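tree_zeros_like and tree_lerp combine naturally into an exponential moving average (EMA) of pytrees, because tree_lerp(ema, new, t) equals (1 - t) * ema + t * new. The sketch below is a hypothetical reimplementation for illustration; how gradhpo itself wires these helpers into its EMA updates is not specified here.

```python
import jax
import jax.numpy as jnp


# Hypothetical reimplementations for illustration; not the gradhpo source.
def tree_zeros_like(tree):
    """Zeros with the same structure, shapes, and dtypes as tree."""
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


def tree_lerp(a, b, t):
    """Leafwise (1 - t) * a + t * b."""
    return jax.tree_util.tree_map(lambda x, y: (1.0 - t) * x + t * y, a, b)


# Usage: EMA over a short stream of gradient-like pytrees, smoothing factor t = 0.5.
grads_stream = [{"w": jnp.array([1.0, 2.0])}, {"w": jnp.array([3.0, 4.0])}]
ema = tree_zeros_like(grads_stream[0])
for g in grads_stream:
    ema = tree_lerp(ema, g, 0.5)
# ema["w"] is now [1.75, 2.5]
```

Starting the average from tree_zeros_like biases early estimates toward zero: with the two inputs above the result is [1.75, 2.5] rather than the plain mean [2.0, 3.0].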
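One way to realize vjp_wrt_lambda is with jax.vjp: closing over w and batch and differentiating only with respect to lam yields alpha @ dPhi/dlambda without materializing the Jacobian. This sketch assumes update_fn(w, lam, batch) returns updated weights with the same pytree structure as w and alpha; that argument order, the example update sgd_step (one SGD step on an L2-regularized least-squares loss, with lam as the regularization strength), and the learning rate LR are all hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical sketch; assumes update_fn(w, lam, batch) -> new w.
def vjp_wrt_lambda(update_fn, w, lam, batch, alpha):
    # Only lam is a primal input, so the pullback returns a cotangent for lam alone.
    _, pullback = jax.vjp(lambda l: update_fn(w, l, batch), lam)
    (grad_lam,) = pullback(alpha)
    return grad_lam


LR = 0.1  # hypothetical inner learning rate


def sgd_step(w, lam, batch):
    """Hypothetical update map Phi: one SGD step on MSE + 0.5 * lam * ||W||^2."""
    x, y = batch

    def loss(params):
        pred = x @ params["W"]
        return jnp.mean((pred - y) ** 2) + 0.5 * lam * jnp.sum(params["W"] ** 2)

    grads = jax.grad(loss)(w)
    return jax.tree_util.tree_map(lambda p, gp: p - LR * gp, w, grads)


w = {"W": jnp.array([1.0, 2.0])}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([0.5, -0.5])
alpha = {"W": jnp.ones(2)}
g = vjp_wrt_lambda(sgd_step, w, jnp.array(0.01), (x, y), alpha)
# Analytically dPhi/dlam = -LR * W, so g should be about -0.1 * (1 + 2) = -0.3
```

For this update the closed form dPhi/dlambda = -LR * W gives a quick correctness check for the VJP; in unrolled hypergradient schemes, alpha would typically be the adjoint propagated backward through later updates.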