SToG: Stochastic Gating for Feature Selection
Feature selection using stochastic gating methods for neural networks
Welcome
SToG is a PyTorch library implementing stochastic gating methods for feature selection.
Key Methods
STG (Stochastic Gates) - Gaussian-based continuous relaxation
STE (Straight-Through Estimator) - Hard binary gates whose gradients are passed straight through the thresholding step
Gumbel-Softmax - Categorical distribution relaxation
Correlated STG - Variant for redundant or correlated features
L1 - Baseline L1 regularization
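The STG and Gumbel-Softmax entries above both replace a discrete keep/drop decision with a differentiable random sample. The plain-Python sketch below shows only the forward sampling step. It assumes the Gaussian-clipping formulation of STG from Yamada et al. (2020) and the standard Gumbel-Softmax trick. The function names `stg_gate` and `gumbel_softmax` are hypothetical and are not SToG's API. A real implementation would operate on PyTorch tensors so that gradients reach `mu` and `logits`.

```python
import math
import random


def stg_gate(mu, sigma=0.5, training=True, rng=random):
    """Hypothetical STG gate per feature: clip(mu_d + eps_d, 0, 1), eps_d ~ N(0, sigma^2).

    During training, Gaussian noise is injected. At evaluation time the noise
    is dropped and each gate is the deterministic value clip(mu_d, 0, 1).
    """
    gates = []
    for m in mu:
        eps = rng.gauss(0.0, sigma) if training else 0.0
        gates.append(min(1.0, max(0.0, m + eps)))
    return gates


def gumbel_softmax(logits, tau=0.5, rng=random):
    """Hypothetical relaxed one-hot sample: softmax((logits + g) / tau), g_i ~ Gumbel(0, 1)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise by inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    # Numerically stable softmax: subtract the maximum before exponentiating.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]
```

Lower `tau` makes Gumbel-Softmax samples closer to one-hot. Lower `sigma` makes STG gates closer to their deterministic values.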
Quick Start
pip install stog
import torch
from stog import STGLayer, FeatureSelectionTrainer, create_classification_model
# X_train, y_train, X_val, y_val, X_test, y_test are your data splits, prepared beforehand
# Create the model and the feature selector
model = create_classification_model(n_features=100, n_classes=2)
selector = STGLayer(n_features=100, sigma=0.5)
# Train (lambda_reg weights the regularization term)
trainer = FeatureSelectionTrainer(
    model=model,
    selector=selector,
    criterion=torch.nn.CrossEntropyLoss(),
    lambda_reg=0.05,
)
trainer.fit(X_train, y_train, X_val, y_val, epochs=300)
# Evaluate on held-out data
result = trainer.evaluate(X_test, y_test)
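In the Quick Start, `lambda_reg` trades prediction quality against the number of features kept. The sketch below assumes the standard STG objective, which adds `lambda_reg` times the expected number of open gates, sum over d of Phi(mu_d / sigma), to the task loss. It also assumes that the selected features are those whose noise-free gate clip(mu_d, 0, 1) is positive. The helpers `expected_open_gates`, `penalized_loss`, and `selected_features` are hypothetical and are not part of SToG, so `FeatureSelectionTrainer` may compute its loss differently.

```python
import math


def expected_open_gates(mu, sigma):
    """Hypothetical expected-L0 penalty: sum_d Phi(mu_d / sigma), Phi = standard normal CDF."""
    return sum(0.5 * (1.0 + math.erf(m / (sigma * math.sqrt(2.0)))) for m in mu)


def penalized_loss(task_loss, mu, sigma, lambda_reg):
    """Hypothetical objective: task loss plus lambda_reg times the expected number of open gates."""
    return task_loss + lambda_reg * expected_open_gates(mu, sigma)


def selected_features(mu):
    """Hypothetical selection rule: keep feature d when its noise-free gate clip(mu_d, 0, 1) > 0."""
    return [d for d, m in enumerate(mu) if min(1.0, max(0.0, m)) > 0.0]
```

For example, with learned means `[0.9, -0.3, 0.4]`, `selected_features` keeps features 0 and 2. Under this objective, a larger `lambda_reg` pushes more means below zero, so fewer features are kept.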
Next Steps
Installation - Installation guide
Training - Training and benchmarking
API Reference - Documentation of the library's classes and functions
About - About and citation