SToG: Stochastic Gating for Feature Selection

Feature selection using stochastic gating methods for neural networks

Welcome

SToG is a PyTorch library implementing stochastic gating methods for feature selection.

Key Methods

  • STG (Stochastic Gates) - Gaussian-based continuous relaxation

  • STE (Straight-Through Estimator) - Hard binary gates; gradients pass straight through the thresholding step

  • Gumbel-Softmax - Categorical distribution relaxation

  • Correlated STG - For redundant/correlated features

  • L1 - Baseline L1 regularization
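The first three gate types differ mainly in how a discrete keep-or-drop decision is made differentiable. The PyTorch sketch below shows one common way to write each, under stated assumptions. The STG gate uses z = clamp(mu + eps, 0, 1) with eps ~ N(0, sigma^2) and penalizes the expected number of open gates, sum_d Phi(mu_d / sigma). The STE gate thresholds sigmoid(logits) in the forward pass and reuses the sigmoid gradient in the backward pass. The Gumbel-Softmax gate relaxes a two-class (off/on) choice per feature. The classes STGGate, STEGate and GumbelGate, their initializations and their regularizer() methods are hypothetical illustrations, not SToG's STGLayer or its other selectors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketches of three gate types; not SToG's actual classes.


class STGGate(nn.Module):
    """Gaussian-based gate: z = clamp(mu + eps, 0, 1), eps ~ N(0, sigma^2)."""

    def __init__(self, n_features: int, sigma: float = 0.5):
        super().__init__()
        # Assumed initialization: mu = 0.5, so every gate starts half open.
        self.mu = nn.Parameter(torch.full((n_features,), 0.5))
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Noise only while training; the gate is deterministic at eval time.
        noise = self.sigma * torch.randn_like(self.mu) if self.training else 0.0
        z = torch.clamp(self.mu + noise, 0.0, 1.0)
        return x * z

    def regularizer(self) -> torch.Tensor:
        # Expected number of open gates: sum_d P(mu_d + eps_d > 0) = sum_d Phi(mu_d / sigma).
        normal = torch.distributions.Normal(0.0, 1.0)
        return normal.cdf(self.mu / self.sigma).sum()


class STEGate(nn.Module):
    """Hard 0/1 gate in the forward pass, sigmoid gradient in the backward pass."""

    def __init__(self, n_features: int):
        super().__init__()
        # Assumed initialization: logits = 1, so sigmoid > 0.5 and all gates start open.
        self.logits = nn.Parameter(torch.ones(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(self.logits)
        hard = (probs > 0.5).float()
        # Straight-through trick: the value equals `hard`, the gradient is that of `probs`.
        z = hard + (probs - probs.detach())
        return x * z

    def regularizer(self) -> torch.Tensor:
        return torch.sigmoid(self.logits).sum()


class GumbelGate(nn.Module):
    """Per-feature two-class (off/on) Gumbel-Softmax relaxation."""

    def __init__(self, n_features: int, tau: float = 0.5):
        super().__init__()
        # Column 0 = "off", column 1 = "on"; biased towards "on" at initialization.
        self.logits = nn.Parameter(torch.tensor([0.0, 1.0]).repeat(n_features, 1))
        self.tau = tau

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Soft sample of the "on" probability for each feature.
            z = F.gumbel_softmax(self.logits, tau=self.tau, hard=False, dim=-1)[:, 1]
        else:
            # Deterministic decision at eval time: keep the feature if "on" wins.
            z = (self.logits.argmax(dim=-1) == 1).float()
        return x * z

    def regularizer(self) -> torch.Tensor:
        # Sum of per-feature "on" probabilities.
        return torch.softmax(self.logits, dim=-1)[:, 1].sum()


# Usage: each gate scales the input features and exposes a sparsity penalty.
x = torch.randn(4, 10)
for gate in (STGGate(10), STEGate(10), GumbelGate(10)):
    loss = gate(x).pow(2).mean() + 0.05 * gate.regularizer()
    loss.backward()
```

Each gate multiplies the input features elementwise, so any of them can sit in front of an ordinary classifier; the regularizer() term is what a sparsity weight such as lambda_reg would scale.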

Quick Start

pip install stog

import torch
from stog import STGLayer, FeatureSelectionTrainer, create_classification_model

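# Placeholder data: random tensors standing in for a real dataset
# (assumption: fit/evaluate accept float feature tensors and integer class labels)
X_train, y_train = torch.randn(800, 100), torch.randint(0, 2, (800,))
X_val, y_val = torch.randn(100, 100), torch.randint(0, 2, (100,))
X_test, y_test = torch.randn(100, 100), torch.randint(0, 2, (100,))

# Create model and selector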
model = create_classification_model(n_features=100, n_classes=2)
selector = STGLayer(n_features=100, sigma=0.5)

# Train
trainer = FeatureSelectionTrainer(
    model=model,
    selector=selector,
    criterion=torch.nn.CrossEntropyLoss(),
    lambda_reg=0.05
)

trainer.fit(X_train, y_train, X_val, y_val, epochs=300)
result = trainer.evaluate(X_test, y_test)
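FeatureSelectionTrainer hides the optimization loop. As a hedged illustration of what lambda_reg most likely controls, the sketch below trains a plain PyTorch model on synthetic data with the loss task_loss + lambda_reg * (expected number of open STG gates), then keeps the features whose deterministic gate clamp(mu, 0, 1) exceeds 0.5. The loop, the Adam optimizer, the 0.5 threshold and the synthetic data are assumptions for illustration, not the trainer's actual internals.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic task (assumption): only the first 5 of 100 features carry signal.
n, d = 512, 100
X = torch.randn(n, d)
y = (X[:, :5].sum(dim=1) > 0).long()

sigma, lambda_reg = 0.5, 0.05
mu = nn.Parameter(torch.full((d,), 0.5))  # one gate parameter per feature
model = nn.Sequential(nn.Linear(d, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(list(model.parameters()) + [mu], lr=1e-2)
normal = torch.distributions.Normal(0.0, 1.0)

for epoch in range(300):
    # Stochastic gates: z = clamp(mu + eps, 0, 1), eps ~ N(0, sigma^2).
    z = torch.clamp(mu + sigma * torch.randn(d), 0.0, 1.0)
    task_loss = criterion(model(X * z), y)
    # Sparsity penalty: expected number of open gates, sum_d Phi(mu_d / sigma).
    reg = normal.cdf(mu / sigma).sum()
    loss = task_loss + lambda_reg * reg
    opt.zero_grad()
    loss.backward()
    opt.step()

# Deterministic gate at inference time; keep features whose gate exceeds 0.5.
with torch.no_grad():
    gates = torch.clamp(mu, 0.0, 1.0)
selected = torch.nonzero(gates > 0.5).flatten().tolist()
print(f"{len(selected)} features kept:", selected[:10])
```

Raising lambda_reg in this sketch pushes more mu values below the threshold and so keeps fewer features; lowering it favors task accuracy over sparsity.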

Next Steps

  • Installation - Installation guide

  • Training - Training and benchmarking

  • API Reference - Classes and functions provided by the library

  • About - Project information and citation