SToG.trainer
Training utilities for feature selection.
Classes
FeatureSelectionTrainer – Trainer with proper lambda search and early stopping.
- class SToG.trainer.FeatureSelectionTrainer(model, selector, criterion, lambda_reg=0.1, device='cpu')
Bases: object
Trainer with proper lambda search and early stopping. Handles joint training of classification model and feature selector.
- __init__(model, selector, criterion, lambda_reg=0.1, device='cpu')
Initialize trainer.
- Parameters:
model – Classification model (nn.Module)
selector – Feature selector (BaseFeatureSelector)
criterion – Loss function
lambda_reg – Regularization strength
device – Device to run on
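This page does not spell out how lambda_reg enters the objective. A common pattern in joint feature-selection training is to add a penalty from the selector to the task loss, weighted by lambda_reg. The sketch below assumes that form. The function joint_loss and its arguments are hypothetical names used only for illustration; they are not part of SToG.

```python
def joint_loss(task_loss: float, selector_penalty: float, lambda_reg: float = 0.1) -> float:
    """Hypothetical combined objective (an assumption, not SToG's confirmed formula).

    task_loss        -- value returned by the criterion (e.g. cross-entropy)
    selector_penalty -- sparsity penalty reported by the feature selector
    lambda_reg       -- regularization strength; larger values favor fewer features
    """
    return task_loss + lambda_reg * selector_penalty


# Example: a task loss of 0.70 and a selector penalty of 12.0 (made-up values).
total = joint_loss(task_loss=0.70, selector_penalty=12.0, lambda_reg=0.1)
print(round(total, 2))
```

Under this assumption, raising lambda_reg shifts the balance toward sparsity and away from classification accuracy, which is presumably why the value is worth searching over.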
- fit(X_train, y_train, X_val, y_val, epochs=300, patience=50, verbose=False)
Train the model with early stopping.
- Parameters:
X_train – Training features [N_train, D]
y_train – Training labels [N_train]
X_val – Validation features [N_val, D]
y_val – Validation labels [N_val]
epochs – Maximum number of epochs
patience – Early stopping patience (number of epochs without validation improvement before training stops)
verbose – Whether to print progress
- Returns:
Training history dictionary
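The interaction between epochs and patience can be illustrated with a minimal, framework-free sketch of patience-based early stopping. It assumes that validation loss is the monitored quantity and that the history is a dictionary of per-epoch lists. The function fit_with_early_stopping, its callbacks, and the history keys are hypothetical and may differ from what fit actually records.

```python
from typing import Callable, Dict


def fit_with_early_stopping(
    train_step: Callable[[int], float],
    validate: Callable[[int], float],
    epochs: int = 300,
    patience: int = 50,
    verbose: bool = False,
) -> Dict[str, object]:
    """Run up to `epochs` epochs; stop after `patience` epochs without improvement."""
    # Hypothetical history layout; the real trainer may use other keys.
    history: Dict[str, object] = {"train_loss": [], "val_loss": [], "best_epoch": -1}
    best_val = float("inf")
    epochs_without_improvement = 0

    for epoch in range(epochs):
        train_loss = train_step(epoch)   # one pass over the training data
        val_loss = validate(epoch)       # evaluate on the held-out split
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        if val_loss < best_val:
            # New best: remember it and reset the patience counter.
            best_val = val_loss
            history["best_epoch"] = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                if verbose:
                    print(f"Stopping at epoch {epoch}; best epoch was {history['best_epoch']}")
                break

    return history


# Toy curves: validation loss falls until epoch 5, then rises again.
history = fit_with_early_stopping(
    train_step=lambda epoch: 1.0 / (epoch + 1),
    validate=lambda epoch: abs(epoch - 5) + 1.0,
    epochs=100,
    patience=3,
)
print(history["best_epoch"], len(history["val_loss"]))
```

With patience=3, the toy run stops three epochs after its best validation loss instead of using all 100 epochs. A real trainer would typically also restore the model weights saved at the best epoch.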