trainer

Training Utilities

Training utilities for feature selection.

class SToG.trainer.FeatureSelectionTrainer(model, selector, criterion, lambda_reg=0.1, device='cpu')[source]

Bases: object

Trainer with proper lambda search and early stopping. Handles joint training of classification model and feature selector.

__init__(model, selector, criterion, lambda_reg=0.1, device='cpu')[source]

Initialize trainer.

Parameters:
  • model – Classification model (nn.Module)

  • selector – Feature selector (BaseFeatureSelector)

  • criterion – Loss function

  • lambda_reg – Regularization strength

  • device – Device to run on

train_epoch(X_train, y_train, X_val, y_val)[source]

Train for one epoch.

fit(X_train, y_train, X_val, y_val, epochs=300, patience=50, verbose=False)[source]

Train the model with early stopping.

Parameters:
  • X_train – Training features [N_train, D]

  • y_train – Training labels [N_train]

  • X_val – Validation features [N_val, D]

  • y_val – Validation labels [N_val]

  • epochs – Maximum number of epochs

  • patience – Early stopping patience

  • verbose – Whether to print progress

Returns:

Training history dictionary

evaluate(X_test, y_test)[source]

Evaluate on test set.

Parameters:
  • X_test – Test features [N_test, D]

  • y_test – Test labels [N_test]

Returns:

Dictionary with test metrics

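Example

A minimal usage sketch, not taken from the package's own docs: the nn.Linear classifier, the synthetic data, and the GateSelector name (standing in for a concrete BaseFeatureSelector subclass) are illustrative assumptions; only FeatureSelectionTrainer and its documented arguments come from this page.

import torch
import torch.nn as nn

from SToG.trainer import FeatureSelectionTrainer

D, n_classes = 20, 2
X_train, y_train = torch.randn(120, D), torch.randint(0, n_classes, (120,))
X_val, y_val = torch.randn(40, D), torch.randint(0, n_classes, (40,))
X_test, y_test = torch.randn(40, D), torch.randint(0, n_classes, (40,))

model = nn.Linear(D, n_classes)         # any nn.Module classifier
selector = GateSelector(n_features=D)   # hypothetical BaseFeatureSelector subclass

trainer = FeatureSelectionTrainer(
    model=model,
    selector=selector,
    criterion=nn.CrossEntropyLoss(),
    lambda_reg=0.1,
    device="cpu",
)

history = trainer.fit(X_train, y_train, X_val, y_val,
                      epochs=300, patience=50, verbose=True)
metrics = trainer.evaluate(X_test, y_test)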

Overview

The SToG.trainer.FeatureSelectionTrainer handles joint optimization of a classification model and a feature selector. It implements:

  • Two-optimizer approach - Separate optimizers for model and selector

  • Early stopping - Validation-based stopping with configurable patience

  • Gradient clipping - Prevents gradient explosion (see the sketch after this list)

  • History tracking - Records metrics for analysis

  • Model checkpointing - Saves best model state

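Gradient clipping is the one mechanism above that maps directly onto a PyTorch built-in. A sketch of the per-step call; max_norm=1.0 is an assumed value, not a documented trainer default:

import torch

# Cap gradient norms on both parameter sets before each optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(selector.parameters(), max_norm=1.0)
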
Architecture

Input Data
    │
    ├─> Selector (Feature Gates)
    │        │
    │   [Gate Parameters]
    │
    └─> Model (Classifier)
            │
        [Model Parameters]
            │
       Output Logits
            │
       Classification Loss + Regularization Loss
            │
      ┌─────┴──────┐
      │            │
    Model       Selector
  Optimizer    Optimizer
  (lr=0.001)   (lr=0.01)
      │            │
      └─────┬──────┘
            │
     Update Parameters

Joint Loss Function

The trainer optimizes:

\[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}}(\mathbf{f}, \mathbf{g}) + \lambda \Omega(\mathbf{g})\]

where:

  • \(\mathcal{L}_{\text{task}}\) is the classification loss (CrossEntropyLoss)

  • \(\Omega(\mathbf{g})\) is the regularization term supplied by the selector

  • \(\lambda\) controls sparsity-accuracy trade-off

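A sketch of how one joint update might realize this objective. selector(X) applying the gates and selector.regularization() returning \(\Omega(\mathbf{g})\) are assumed interface names; model_opt and selector_opt are the two optimizers constructed in the next section:

model_opt.zero_grad()
selector_opt.zero_grad()

logits = model(selector(X_train))        # gated forward pass
task_loss = criterion(logits, y_train)   # L_task
total_loss = task_loss + lambda_reg * selector.regularization()  # + lambda * Omega(g)
total_loss.backward()
# (gradient clipping, as sketched earlier, would go here)

model_opt.step()      # update model parameters f
selector_opt.step()   # update gate parameters g
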
Two-Optimizer Strategy

Model Optimizer:
  • Lower learning rate (default: 0.001)

  • Updates classification parameters \(\mathbf{f}\)

  • Learns from task loss

Selector Optimizer:
  • Higher learning rate (default: 0.01)

  • Updates gate parameters \(\mathbf{g}\)

  • Learns from task + regularization loss

  • The 10x higher learning rate lets the gate parameters adapt faster than the model weights

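A sketch of the corresponding construction; Adam is an assumption, since the docs state the learning rates but not the optimizer class:

import torch

model_opt = torch.optim.Adam(model.parameters(), lr=1e-3)        # model: lr=0.001
selector_opt = torch.optim.Adam(selector.parameters(), lr=1e-2)  # selector: lr=0.01
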
Early Stopping

Early stopping monitors the validation loss. If it does not improve for patience consecutive epochs, training stops and the best saved model state is restored.
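
An illustrative loop for this logic, not the trainer's actual internals; train_epoch is assumed here to return the epoch's validation loss:

import copy

best_val, best_state, stale = float("inf"), None, 0
for epoch in range(epochs):
    val_loss = trainer.train_epoch(X_train, y_train, X_val, y_val)
    if val_loss < best_val:
        best_val, stale = val_loss, 0
        best_state = copy.deepcopy(model.state_dict())  # checkpoint best model
    else:
        stale += 1
        if stale >= patience:  # no improvement for patience epochs
            break
if best_state is not None:
    model.load_state_dict(best_state)  # restore best checkpoint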