Core (core)

Types

Core type definitions for gradhpo.

Provides type aliases and data structures used across all algorithms.

class gradhpo.core.types.DataBatch(inputs: Array, targets: Array)[source]

Bases: NamedTuple

Structure for a data batch.

Variables:
  • inputs (jax.Array) – Input features [batch_size, …]

  • targets (jax.Array) – Target labels [batch_size, …]

inputs: Array

Alias for field number 0

targets: Array

Alias for field number 1
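
Example (a minimal sketch; the shapes and dtypes are illustrative, any arrays with a leading batch dimension work):

  import jax.numpy as jnp
  from gradhpo.core.types import DataBatch

  # 32 examples, 10 input features, integer class labels.
  batch = DataBatch(
      inputs=jnp.ones((32, 10)),
      targets=jnp.zeros((32,), dtype=jnp.int32),
  )
  batch.inputs.shape   # (32, 10)
  batch.targets.shape  # (32,)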

class gradhpo.core.types.LossFunctions(train_loss: Callable[[Any, Any, Any], Array], val_loss: Callable[[Any, Any, Any], Array])[source]

Bases: NamedTuple

Container for loss functions.

Variables:
  • train_loss (Callable[[Any, Any, Any], jax.Array]) – Training loss function

  • val_loss (Callable[[Any, Any, Any], jax.Array]) – Validation loss function

train_loss: Callable[[Any, Any, Any], Array]

Alias for field number 0

val_loss: Callable[[Any, Any, Any], Array]

Alias for field number 1
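
Example (a sketch that assumes the three positional arguments are (params, hyperparams, batch); the exact calling convention is fixed by the algorithm consuming these callables, and the ridge-style losses are illustrative):

  import jax.numpy as jnp
  from gradhpo.core.types import DataBatch, LossFunctions

  def train_loss(params, hyperparams, batch):
      # Inner objective: squared error plus an L2 penalty whose
      # weight is the hyperparameter being tuned.
      preds = batch.inputs @ params["w"]
      mse = jnp.mean((preds - batch.targets) ** 2)
      return mse + hyperparams["l2"] * jnp.sum(params["w"] ** 2)

  def val_loss(params, hyperparams, batch):
      # Outer objective: plain, unregularized validation error.
      preds = batch.inputs @ params["w"]
      return jnp.mean((preds - batch.targets) ** 2)

  losses = LossFunctions(train_loss=train_loss, val_loss=val_loss)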

Optimization State

BilevelState: state container for bilevel optimization.

class gradhpo.core.state.BilevelState(params: Any, hyperparams: Any, inner_opt_state: Any, outer_opt_state: Any, step: int, metadata: Dict[str, Any] = <factory>)[source]

Bases: object

State container for bilevel optimization process.

Registered as a JAX pytree so that instances can be passed through jax.jit, jax.grad, etc.

Pytree layout

Leaves: params, hyperparams, inner_opt_state, outer_opt_state, and the values of metadata (in sorted-key order).

Aux data: step (int) and the sorted keys of metadata (tuple of strings). Both are Python scalars / tuples and are therefore treated as static by JAX.
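
For example (field values are illustrative), flattening an instance separates the traced leaves from the static aux data:

  import jax
  import jax.numpy as jnp
  from gradhpo.core.state import BilevelState

  state = BilevelState(
      params={"w": jnp.zeros(3)},
      hyperparams={"l2": jnp.array(0.1)},
      inner_opt_state=None,
      outer_opt_state=None,
      step=0,
      metadata={"val_loss": jnp.array(1.0)},
  )
  leaves, treedef = jax.tree_util.tree_flatten(state)
  # leaves: [w, l2, val_loss]; step and the metadata key "val_loss"
  # live in treedef and stay static under jax.jit.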

Variables:
  • params (Any) – Model parameters (inner level).

  • hyperparams (Any) – Hyperparameters to optimize (outer level).

  • inner_opt_state (Any) – State of inner optimizer.

  • outer_opt_state (Any) – State of outer optimizer.

  • step (int) – Current optimization step.

  • metadata (Dict[str, Any]) – Additional information (losses, norms, etc.). Values may be JAX arrays or plain Python scalars.

params: Any

hyperparams: Any

inner_opt_state: Any

outer_opt_state: Any

step: int

metadata: Dict[str, Any]

classmethod create(params: Any, hyperparams: Any, inner_opt_state: Any, outer_opt_state: Any) → BilevelState[source]

Create a new BilevelState with initial values.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

  • inner_opt_state – Initial inner optimizer state.

  • outer_opt_state – Initial outer optimizer state.

Returns:

Initialized state object.
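
Example (a sketch pairing the state with optax optimizers; the parameter and hyperparameter structures are illustrative, and step/metadata are assumed to start at 0 and an empty dict):

  import jax.numpy as jnp
  import optax
  from gradhpo.core.state import BilevelState

  params = {"w": jnp.zeros(10)}
  hyperparams = {"l2": jnp.array(1e-3)}

  inner_opt = optax.sgd(1e-2)
  outer_opt = optax.adam(1e-3)

  state = BilevelState.create(
      params=params,
      hyperparams=hyperparams,
      inner_opt_state=inner_opt.init(params),
      outer_opt_state=outer_opt.init(hyperparams),
  )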

update(params: Any | None = None, hyperparams: Any | None = None, inner_opt_state: Any | None = None, outer_opt_state: Any | None = None, step: int | None = None, metadata: Dict[str, Any] | None = None) → BilevelState[source]

Create updated state with new values.

Parameters:
  • params – New parameters (if None, keep current).

  • hyperparams – New hyperparameters (if None, keep current).

  • inner_opt_state – New inner optimizer state.

  • outer_opt_state – New outer optimizer state.

  • step – New step count.

  • metadata – New metadata to merge.

Returns:

New state object with updates.
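
Example of the copy-on-write pattern (continuing the state from the create() example; the metric name is illustrative):

  new_state = state.update(
      step=state.step + 1,
      metadata={"val_loss": 0.42},  # merged into existing metadata
  )
  # Fields left as None (params, hyperparams, both optimizer states)
  # are carried over; state itself is not mutated.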

get_metric(key: str, default: Any = None) → Any[source]

Retrieve a metric from metadata.

Parameters:
  • key – Metric name.

  • default – Default value if key not found.

Returns:

Metric value or default.
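
Continuing the example above (metric names are illustrative):

  val_loss = new_state.get_metric("val_loss", default=float("inf"))  # 0.42
  grad_norm = new_state.get_metric("grad_norm")  # None: absent, no default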

__init__(params: Any, hyperparams: Any, inner_opt_state: Any, outer_opt_state: Any, step: int, metadata: Dict[str, Any] = <factory>) → None

Base Optimizer

BilevelOptimizer: abstract base class for bilevel hyperparameter optimizers.

class gradhpo.core.base.BilevelOptimizer(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None)[source]

Bases: ABC

Abstract base class for bilevel hyperparameter optimizers.

All algorithms inherit from this class and implement the required abstract methods.

Variables:
  • inner_optimizer – Optax optimizer for parameters.

  • outer_optimizer – Optax optimizer for hyperparameters.

__init__(inner_optimizer: GradientTransformation | None = None, outer_optimizer: GradientTransformation | None = None)[source]

Initialize the bilevel optimizer with the given inner (parameter) and outer (hyperparameter) optax optimizers.

abstractmethod init(params: Any, hyperparams: Any) → BilevelState[source]

Initialize the bilevel optimization state.

Parameters:
  • params – Initial model parameters.

  • hyperparams – Initial hyperparameters.

Returns:

Initial optimization state.

abstractmethod step(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) → BilevelState[source]

Perform one bilevel optimization step.

Parameters:
  • state – Current bilevel state.

  • train_batch – Training data batch.

  • val_batch – Validation data batch.

  • train_loss_fn – Training loss function.

  • val_loss_fn – Validation loss function.

Returns:

Updated state.
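
A typical driver loop looks like the following sketch, where optimizer is any concrete BilevelOptimizer subclass and train_batches / val_batches are assumed iterables of DataBatch (losses as in the LossFunctions example above):

  state = optimizer.init(params, hyperparams)
  for train_batch, val_batch in zip(train_batches, val_batches):
      state = optimizer.step(
          state, train_batch, val_batch,
          train_loss_fn=losses.train_loss,
          val_loss_fn=losses.val_loss,
      )
  tuned = state.hyperparams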

abstractmethod compute_hypergradient(state: BilevelState, train_batch: DataBatch, val_batch: DataBatch, train_loss_fn: Callable[[Any, Any, Any], Array], val_loss_fn: Callable[[Any, Any, Any], Array]) → Any[source]

Compute hypergradient w.r.t. hyperparameters.

Parameters:
  • state – Current state.

  • train_batch – Training batch.

  • val_batch – Validation batch.

  • train_loss_fn – Training loss.

  • val_loss_fn – Validation loss.

Returns:

Hypergradient with same structure as hyperparams.
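
The sketch below shows a minimal concrete subclass, assuming the loss signature (params, hyperparams, batch) and that both optax optimizers are passed at construction. Its hypergradient is only the direct gradient of the validation loss w.r.t. the hyperparameters; for hyperparameters that do not appear in the validation loss (e.g. a regularization weight) that term is zero, which is why practical bilevel algorithms also compute an indirect term through the inner optimization. This is an illustration of the interface, not one of gradhpo's algorithms:

  import jax
  import optax
  from gradhpo.core.base import BilevelOptimizer
  from gradhpo.core.state import BilevelState

  class NaiveBilevel(BilevelOptimizer):
      def init(self, params, hyperparams):
          return BilevelState.create(
              params=params,
              hyperparams=hyperparams,
              inner_opt_state=self.inner_optimizer.init(params),
              outer_opt_state=self.outer_optimizer.init(hyperparams),
          )

      def compute_hypergradient(self, state, train_batch, val_batch,
                                train_loss_fn, val_loss_fn):
          # Direct gradient only; no term for how params respond
          # to hyperparams.
          return jax.grad(val_loss_fn, argnums=1)(
              state.params, state.hyperparams, val_batch)

      def step(self, state, train_batch, val_batch,
               train_loss_fn, val_loss_fn):
          # Inner step: descend the training loss in params.
          g = jax.grad(train_loss_fn)(
              state.params, state.hyperparams, train_batch)
          upd, inner_st = self.inner_optimizer.update(
              g, state.inner_opt_state, state.params)
          params = optax.apply_updates(state.params, upd)

          # Outer step: descend the hypergradient in hyperparams.
          hg = self.compute_hypergradient(
              state, train_batch, val_batch, train_loss_fn, val_loss_fn)
          hupd, outer_st = self.outer_optimizer.update(
              hg, state.outer_opt_state, state.hyperparams)
          hyperparams = optax.apply_updates(state.hyperparams, hupd)

          return state.update(
              params=params, hyperparams=hyperparams,
              inner_opt_state=inner_st, outer_opt_state=outer_st,
              step=state.step + 1)

  optimizer = NaiveBilevel(optax.sgd(1e-2), optax.adam(1e-3))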