Core (core)
Types
Core type definitions for gradhpo.
Provides type aliases and data structures used across all algorithms.
- class gradhpo.core.types.DataBatch(inputs: Array, targets: Array)[source]
Bases: NamedTuple
Structure for a data batch.
- Variables:
inputs (jax.jaxlib._jax.Array) – Input features [batch_size, …]
targets (jax.jaxlib._jax.Array) – Target labels [batch_size, …]
- inputs: Array
Alias for field number 0
- targets: Array
Alias for field number 1
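Since DataBatch is a NamedTuple, fields can be read by name, by position, or via tuple unpacking. The sketch below mirrors the documented layout with a locally defined stand-in (plain lists instead of JAX arrays, purely for illustration):

```python
from typing import NamedTuple, Any

# Illustrative stand-in mirroring the documented DataBatch layout;
# in the library the fields are JAX arrays of shape [batch_size, ...].
class DataBatch(NamedTuple):
    inputs: Any   # input features
    targets: Any  # target labels

batch = DataBatch(inputs=[[0.1, 0.2], [0.3, 0.4]], targets=[0, 1])

# NamedTuple gives attribute access, positional access, and unpacking.
assert batch.inputs is batch[0]   # "Alias for field number 0"
assert batch.targets is batch[1]  # "Alias for field number 1"
inputs, targets = batch
```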
- class gradhpo.core.types.LossFunctions(train_loss: Callable[[Any, Any, Any], Array], val_loss: Callable[[Any, Any, Any], Array])[source]
Bases: NamedTuple
Container for loss functions.
- Variables:
train_loss (Callable[[Any, Any, Any], jax.jaxlib._jax.Array]) – Training loss function
val_loss (Callable[[Any, Any, Any], jax.jaxlib._jax.Array]) – Validation loss function
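A hedged sketch of how such a container is used: both losses share the `(params, inputs, targets) -> scalar` signature, so train and validation objectives can be passed around as one value. The `mse` helper below is a hypothetical example loss, not part of gradhpo, and returns a plain float rather than a JAX array for simplicity:

```python
from typing import NamedTuple, Callable, Any

# Illustrative stand-in mirroring the documented LossFunctions container.
class LossFunctions(NamedTuple):
    train_loss: Callable[[Any, Any, Any], float]
    val_loss: Callable[[Any, Any, Any], float]

def mse(params, inputs, targets):
    # Toy linear model: prediction = params * x.
    preds = [params * x for x in inputs]
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

losses = LossFunctions(train_loss=mse, val_loss=mse)
train_value = losses.train_loss(2.0, [1.0, 2.0], [2.0, 4.0])  # perfect fit
```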
Optimization State
BilevelState: state container for bilevel optimization.
- class gradhpo.core.state.BilevelState(params: Any, hyperparams: Any, inner_opt_state: Any, outer_opt_state: Any, step: int, metadata: Dict[str, Any] = <factory>)[source]
Bases: object
State container for the bilevel optimization process.
Registered as a JAX pytree so that instances can be passed through jax.jit, jax.grad, etc.
Pytree layout
- Leaves: params, hyperparams, inner_opt_state, outer_opt_state, and the values of metadata (in sorted-key order).
- Aux data: step (int) and the sorted keys of metadata (tuple of strings). Both are Python scalars / tuples and are therefore treated as static by JAX.
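The leaves/aux-data split can be made concrete with a pure-Python sketch of the documented layout (this is an illustration of the flatten/unflatten contract, not the library's code; a dict stands in for the state object):

```python
# Dynamic leaves flow through JAX transformations; static aux data
# (step, sorted metadata keys) is carried alongside unchanged.

def flatten_state(state: dict):
    keys = tuple(sorted(state["metadata"]))
    leaves = (
        state["params"],
        state["hyperparams"],
        state["inner_opt_state"],
        state["outer_opt_state"],
    ) + tuple(state["metadata"][k] for k in keys)
    aux = (state["step"], keys)  # static: Python int and tuple of strings
    return leaves, aux

def unflatten_state(aux, leaves):
    step, keys = aux
    params, hyperparams, inner, outer, *meta_values = leaves
    return {
        "params": params,
        "hyperparams": hyperparams,
        "inner_opt_state": inner,
        "outer_opt_state": outer,
        "step": step,
        "metadata": dict(zip(keys, meta_values)),
    }

state = {"params": 1.0, "hyperparams": 0.1, "inner_opt_state": (),
         "outer_opt_state": (), "step": 3, "metadata": {"loss": 0.5}}
leaves, aux = flatten_state(state)
assert unflatten_state(aux, leaves) == state  # round-trip preserves the state
```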
- ivar params:
Model parameters (inner level).
- vartype params:
Any
- ivar hyperparams:
Hyperparameters to optimize (outer level).
- vartype hyperparams:
Any
- ivar inner_opt_state:
State of inner optimizer.
- vartype inner_opt_state:
Any
- ivar outer_opt_state:
State of outer optimizer.
- vartype outer_opt_state:
Any
- ivar step:
Current optimization step.
- vartype step:
int
- ivar metadata:
Additional information (losses, norms, etc.). Values may be JAX arrays or plain Python scalars.
- vartype metadata:
Dict[str, Any]
- classmethod create(params: Any, hyperparams: Any, inner_opt_state: Any, outer_opt_state: Any) → BilevelState[source]
Create a new BilevelState with initial values.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
inner_opt_state – Initial inner optimizer state.
outer_opt_state – Initial outer optimizer state.
- Returns:
Initialized state object.
- update(params: Any | None = None, hyperparams: Any | None = None, inner_opt_state: Any | None = None, outer_opt_state: Any | None = None, step: int | None = None, metadata: Dict[str, Any] | None = None) → BilevelState[source]
Create updated state with new values.
- Parameters:
params – New parameters (if None, keep current).
hyperparams – New hyperparameters (if None, keep current).
inner_opt_state – New inner optimizer state (if None, keep current).
outer_opt_state – New outer optimizer state (if None, keep current).
step – New step count (if None, keep current).
metadata – New metadata to merge.
- Returns:
New state object with updates.
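The update semantics above can be sketched in pure Python (an illustration of the contract, not the library's implementation; a dict stands in for the state): None means keep the current value, new metadata merges into the old, and the original state is never mutated.

```python
def update_state(state: dict, **changes) -> dict:
    new_state = dict(state)
    for key, value in changes.items():
        if value is None:
            continue  # None means "keep current value"
        if key == "metadata":
            # Merge new metadata into the existing dict rather than replace it.
            new_state["metadata"] = {**state["metadata"], **value}
        else:
            new_state[key] = value
    return new_state

state = {"params": 1.0, "step": 0, "metadata": {"loss": 0.9}}
new = update_state(state, step=1, metadata={"grad_norm": 2.5})
assert new["params"] == 1.0        # unchanged: not passed
assert state["step"] == 0          # original state is untouched
```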
Base Optimizer
BilevelOptimizer: abstract base class for bilevel hyperparameter optimizers.
- class gradhpo.core.base.BilevelOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None)[source]
Bases: ABC
Abstract base class for bilevel hyperparameter optimizers.
All algorithms inherit from this class and implement the required abstract methods.
- Variables:
inner_optimizer – Optax optimizer for parameters.
outer_optimizer – Optax optimizer for hyperparameters.
- __init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None)[source]
Initialize bilevel optimizer.
- abstractmethod init(params: Any, hyperparams: Any) → BilevelState[source]
Initialize the bilevel optimization state.
- Parameters:
params – Initial model parameters.
hyperparams – Initial hyperparameters.
- Returns:
Initial optimization state.
- abstractmethod step(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) → BilevelState[source]
Perform one bilevel optimization step.
- Parameters:
state – Current bilevel state.
train_batch – Training data batch.
val_batch – Validation data batch.
train_loss_fn – Training loss function.
val_loss_fn – Validation loss function.
- Returns:
Updated state.
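The alternating pattern a `step` implementation typically follows, inner update on the training loss then outer update on the validation loss, can be sketched on a tiny analytic bilevel problem. This is a generic illustration of the pattern, not gradhpo's algorithm; the losses and the closed-form inner solution `w*(lam) = 1/(1+lam)` are constructed for this example:

```python
# Inner objective: train_loss(w, lam) = (w - 1)^2 + lam * w^2
# Outer objective: val_loss(w) = (w - 0.5)^2
# The inner optimum is w*(lam) = 1 / (1 + lam).

w, lam = 0.0, 0.0
lr_inner, lr_outer = 0.1, 0.05
for _ in range(2000):
    # Inner step: gradient descent on the training loss w.r.t. params.
    grad_w = 2.0 * (w - 1.0) + 2.0 * lam * w
    w -= lr_inner * grad_w
    # Outer step: hypergradient via the analytic inner response w*(lam).
    w_star = 1.0 / (1.0 + lam)
    grad_lam = 2.0 * (w_star - 0.5) * (-1.0 / (1.0 + lam) ** 2)
    lam = max(0.0, lam - lr_outer * grad_lam)
# Converges toward lam = 1, where w* = 0.5 matches the validation target.
```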
- abstractmethod compute_hypergradient(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) → Any[source]
Compute hypergradient w.r.t. hyperparameters.
- Parameters:
state – Current state.
train_batch – Training batch.
val_batch – Validation batch.
train_loss_fn – Training loss.
val_loss_fn – Validation loss.
- Returns:
Hypergradient with same structure as hyperparams.
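What this method computes can be made concrete on a toy bilevel problem with a closed-form inner solution (an illustration of the hypergradient as a quantity, not of any particular gradhpo algorithm; the losses here are chosen for the example):

```python
# Inner:  w*(lam) = argmin_w (w - 1)^2 + lam * w^2  →  w* = 1 / (1 + lam)
# Outer:  L_val(w) = (w - 0.5)^2
# Chain rule: dL_val/dlam = L_val'(w*(lam)) * dw*/dlam

def inner_solution(lam: float) -> float:
    return 1.0 / (1.0 + lam)

def hypergradient(lam: float) -> float:
    w = inner_solution(lam)
    dLval_dw = 2.0 * (w - 0.5)         # validation-loss gradient at w*
    dw_dlam = -1.0 / (1.0 + lam) ** 2  # response of the inner solution to lam
    return dLval_dw * dw_dlam

# At lam = 1, w* = 0.5 exactly matches the validation target, so the
# hypergradient vanishes: lam = 1 is the outer optimum of this toy problem.
assert abs(hypergradient(1.0)) < 1e-12
```

In the library the same quantity is returned as a pytree with the structure of `hyperparams`; here it is a single scalar because the toy problem has one hyperparameter.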