trainer
Training Utilities
Training utilities for feature selection.
- class SToG.trainer.FeatureSelectionTrainer(model, selector, criterion, lambda_reg=0.1, device='cpu')
Bases: object
Trainer with proper lambda search and early stopping. Handles joint training of the classification model and feature selector.
- __init__(model, selector, criterion, lambda_reg=0.1, device='cpu')
Initialize trainer.
- Parameters:
model – Classification model (nn.Module)
selector – Feature selector (BaseFeatureSelector)
criterion – Loss function
lambda_reg – Regularization strength (λ in the joint loss)
device – Device to run on
- fit(X_train, y_train, X_val, y_val, epochs=300, patience=50, verbose=False)
Train the model with early stopping.
- Parameters:
X_train – Training features [N_train, D]
y_train – Training labels [N_train]
X_val – Validation features [N_val, D]
y_val – Validation labels [N_val]
epochs – Maximum number of epochs
patience – Early stopping patience
verbose – Whether to print progress
- Returns:
Training history dictionary
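The documented fit() behaviour (a maximum of epochs, a patience counter, a returned history dictionary) matches the usual patience-based early-stopping pattern. The sketch below illustrates that pattern only; it is not the SToG implementation. It assumes the monitored quantity is validation loss (lower is better), and the function name fit_with_early_stopping, the callbacks train_step/evaluate/get_state, and the history keys are all hypothetical. In PyTorch, get_state would typically return model.state_dict().

```python
import copy


def fit_with_early_stopping(train_step, evaluate, get_state, epochs=300, patience=50):
    """Hypothetical patience-based training loop (not the SToG implementation).

    train_step(): runs one epoch and returns the training loss.
    evaluate():   returns the validation loss (lower is better).
    get_state():  returns the parameters to checkpoint.
    """
    history = {"train_loss": [], "val_loss": [], "best_epoch": None}
    best_val = float("inf")
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(epochs):
        history["train_loss"].append(train_step())
        val_loss = evaluate()
        history["val_loss"].append(val_loss)

        if val_loss < best_val:
            # New best: checkpoint a copy so later updates don't overwrite it.
            best_val = val_loss
            best_state = copy.deepcopy(get_state())
            history["best_epoch"] = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` consecutive epochs

    return history, best_state


# Toy run with a scripted validation curve.
val_curve = iter([1.0, 0.8, 0.7, 0.75, 0.72, 0.9, 0.6])
state = {"epoch": 0}


def train_step():
    state["epoch"] += 1
    return 1.0 / state["epoch"]


history, best = fit_with_early_stopping(
    train_step, lambda: next(val_curve), lambda: state, epochs=7, patience=3
)
print(history["best_epoch"], best)
```

With patience=3 the loop stops after three consecutive non-improving epochs, so the later value 0.6 is never reached. A larger patience trades extra epochs for a chance to escape such plateaus.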
FeatureSelectionTrainer
Overview
The SToG.trainer.FeatureSelectionTrainer handles joint optimization of a classification
model and a feature selector. It implements:
- Two-optimizer approach: separate optimizers for the model and the selector
- Early stopping: validation-based stopping with configurable patience
- Gradient clipping: limits gradient norms to prevent exploding gradients
- History tracking: records training metrics for later analysis
- Model checkpointing: saves the best model state
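Gradient clipping is most often done by global L2 norm: if the combined norm of all gradients exceeds a threshold, every gradient is scaled by the same factor, which preserves the update direction. The framework-free sketch below mirrors the scaling rule of PyTorch's torch.nn.utils.clip_grad_norm_. Whether this trainer clips jointly or per optimizer, and which max_norm it uses, are not stated here; the helper name clip_grad_norm and the value max_norm=1.0 are assumptions for illustration.

```python
import math


def clip_grad_norm(grad_groups, max_norm):
    """Rescale gradients in place so their global L2 norm is at most max_norm.

    grad_groups is a list of gradient lists (e.g. one per parameter tensor).
    Returns the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(x * x for grads in grad_groups for x in grads))
    if total_norm > max_norm:
        # Same scaling rule as torch.nn.utils.clip_grad_norm_: max_norm / (norm + eps)
        scale = max_norm / (total_norm + 1e-6)
        for grads in grad_groups:
            for i in range(len(grads)):
                grads[i] *= scale
    return total_norm


model_grads = [3.0, 0.0]  # hypothetical classifier gradients
gate_grads = [4.0]        # hypothetical gate gradients
pre_norm = clip_grad_norm([model_grads, gate_grads], max_norm=1.0)
print(pre_norm, model_grads, gate_grads)
```

Here the global norm is 5.0, so both groups are scaled by roughly 0.2 and the ratio between classifier and gate gradients is unchanged.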
Architecture
Input Data
    │
    ├─> Selector (Feature Gates)
    │       │
    │   [Gate Parameters]
    │
    └─> Model (Classifier)
            │
   [Model Parameters]
            │
      Output Logits
            │
Classification Loss + Regularization Loss
            │
     ┌──────┴──────┐
     │             │
   Model       Selector
 Optimizer     Optimizer
(lr=0.001)     (lr=0.01)
     │             │
     └──────┬──────┘
            │
    Update Parameters
Joint Loss Function
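The trainer optimizes the joint objective
\[ \mathcal{L}(\mathbf{f}, \mathbf{g}) = \mathcal{L}_{\text{task}}(\mathbf{f}, \mathbf{g}) + \lambda \, \Omega(\mathbf{g}) \]
where:
\(\mathcal{L}_{\text{task}}\) is the classification loss given by criterion (e.g. CrossEntropyLoss)
\(\Omega(\mathbf{g})\) is the sparsity regularizer computed by the selector from the gate parameters \(\mathbf{g}\)
\(\lambda\) (lambda_reg) controls the sparsity-accuracy trade-off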
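A single-sample worked example makes the two terms concrete. This sketch assumes Ω(g) is the sum of the gate values (roughly, the expected number of open gates); that choice is hypothetical, since the actual regularizer depends on the selector. It uses one sample instead of a batch mean and λ = 0.1, the lambda_reg default. The helper names cross_entropy, gate_penalty, and joint_loss are illustrative only.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one sample: logsumexp(logits) - logits[label]."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def gate_penalty(gates):
    """Hypothetical Omega(g): sum of gate values in [0, 1]."""
    return sum(gates)


def joint_loss(logits, label, gates, lam=0.1):
    task = cross_entropy(logits, label)
    reg = gate_penalty(gates)
    return task + lam * reg, task, reg


# Three-class logits, true class 0, four feature gates.
total, task, reg = joint_loss([2.0, 0.5, -1.0], 0, [0.9, 0.1, 0.5, 0.0], lam=0.1)
print(f"task={task:.4f} reg={reg:.2f} total={total:.4f}")
```

Raising λ makes each open gate cost more relative to the classification term, pushing the selector toward fewer active features.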
Two-Optimizer Strategy
- Model Optimizer:
Lower learning rate (default: 0.001)
Updates classification parameters \(\mathbf{f}\)
Learns from task loss
- Selector Optimizer:
Higher learning rate (default: 0.01)
Updates gate parameters \(\mathbf{g}\)
Learns from task + regularization loss
10× the model's learning rate, so the gate parameters adapt faster than the classifier weights
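One update step of the two-optimizer strategy can be sketched without a framework. This assumes plain SGD for both parameter groups (the actual optimizer type is not specified here) and a hypothetical Ω(g) = sum(g), whose gradient is 1 for every gate. All parameter and gradient values are made up. The point is structural: because Ω does not depend on f, the model step receives only the task gradient, while the selector step receives the task gradient plus λ times the regularization gradient, scaled by a 10× larger learning rate.

```python
MODEL_LR = 0.001    # model optimizer learning rate (documented default)
SELECTOR_LR = 0.01  # selector optimizer learning rate, 10x higher
LAMBDA = 0.1        # regularization strength (lambda_reg default)


def sgd_step(params, grads, lr):
    """Return params moved one plain-SGD step against their gradients."""
    return [p - lr * g for p, g in zip(params, grads)]


# Hypothetical current parameters.
f = [0.5, -0.2]  # classifier weights
g = [0.3, 0.8]   # gate parameters

# Hypothetical gradients from one backward pass.
task_grad_f = [0.4, -0.1]  # dL_task/df
task_grad_g = [0.2, 0.05]  # dL_task/dg
reg_grad_g = [1.0, 1.0]    # dOmega/dg, assuming Omega(g) = sum(g)

# Omega does not depend on f, so the model step sees only the task gradient.
f = sgd_step(f, task_grad_f, MODEL_LR)

# The selector step sees the task gradient plus lambda * regularization gradient.
joint_grad_g = [t + LAMBDA * r for t, r in zip(task_grad_g, reg_grad_g)]
g = sgd_step(g, joint_grad_g, SELECTOR_LR)

print("f:", f)
print("g:", g)
```

In PyTorch the same pattern is typically two torch.optim optimizers, one over model.parameters() and one over selector.parameters(), both stepped after a single loss.backward() on the joint loss.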
Early Stopping
Early stopping monitors