API Reference
Complete API documentation for SToG.
Module Overview
The SToG library consists of several interconnected modules:
- base.py - Abstract base class: defines SToG.base.BaseFeatureSelector, the abstract base for all feature selector implementations.
- selectors.py - Feature selection methods: implements five feature selection methods:
  - SToG.selectors.STGLayer - Stochastic Gates with Gaussian relaxation
  - SToG.selectors.STELayer - Straight-Through Estimator
  - SToG.selectors.GumbelLayer - Gumbel-Softmax categorical relaxation
  - SToG.selectors.CorrelatedSTGLayer - STG for correlated features
  - SToG.selectors.L1Layer - L1 regularization baseline
- trainer.py - Training utilities: provides SToG.trainer.FeatureSelectionTrainer for joint optimization of model and selector.
- models.py - Model factories: provides SToG.models.create_classification_model() for creating neural network classifiers.
- datasets.py - Dataset utilities: provides SToG.datasets.DatasetLoader for loading and preparing datasets.
- benchmark.py - Benchmarking framework: provides SToG.benchmark.ComprehensiveBenchmark for comparing methods across datasets.
- main.py - Main execution: entry point for running benchmarks via SToG.main.main().
Design Philosophy
Modular Architecture
Each feature selector inherits from SToG.base.BaseFeatureSelector, ensuring a consistent interface:
BaseFeatureSelector (Abstract)
├── forward(x) -> x_gated
├── regularization_loss() -> scalar
├── get_selection_probs() -> probabilities
└── get_selected_features(threshold) -> mask
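The interface above could be sketched as follows. This is illustrative only, not the library's actual implementation: method names and return shapes follow the tree, while the bodies (and the default get_selected_features built on get_selection_probs) are assumptions.

```python
from abc import abstractmethod

import torch
import torch.nn as nn


class BaseFeatureSelector(nn.Module):
    """Sketch of the abstract selector interface described above."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.input_dim = input_dim

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x with the gates applied elementwise (x_gated)."""

    @abstractmethod
    def regularization_loss(self) -> torch.Tensor:
        """Scalar penalty encouraging sparse gates."""

    @abstractmethod
    def get_selection_probs(self) -> torch.Tensor:
        """Per-feature selection probabilities, shape (input_dim,)."""

    def get_selected_features(self, threshold: float = 0.5) -> torch.Tensor:
        # Assumed default: threshold the selection probabilities into a mask.
        return self.get_selection_probs() > threshold
```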
Extensibility
New feature selection methods can be implemented by subclassing SToG.base.BaseFeatureSelector
and implementing three methods: forward, regularization_loss, and get_selection_probs.
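A hypothetical new selector following this three-method contract might look like the deterministic sigmoid gate below. In the library it would subclass SToG.base.BaseFeatureSelector; it subclasses torch.nn.Module directly here only so the sketch is self-contained, and the class name, lam parameter, and gating scheme are all illustrative assumptions.

```python
import torch
import torch.nn as nn


class SigmoidGateSelector(nn.Module):
    """Toy selector: each feature is scaled by a learnable sigmoid gate."""

    def __init__(self, input_dim: int, lam: float = 0.1):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(input_dim))  # learnable gate logits
        self.lam = lam  # regularization strength (hypothetical knob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate each feature by its selection probability.
        return x * torch.sigmoid(self.logits)

    def regularization_loss(self) -> torch.Tensor:
        # Penalize the expected number of open gates to encourage sparsity.
        return self.lam * torch.sigmoid(self.logits).sum()

    def get_selection_probs(self) -> torch.Tensor:
        return torch.sigmoid(self.logits)
```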
PyTorch Integration
All components are built on PyTorch:
- Selectors inherit from torch.nn.Module
- Computations use standard PyTorch tensors
- Compatible with PyTorch's optimization and autograd system
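Because selectors are ordinary torch.nn.Module objects, joint optimization reduces to the standard PyTorch training pattern: sum the task loss and the selector's regularization loss, then let autograd differentiate through both. The loop below is a generic sketch of that pattern (with a stand-in sigmoid gate), not the library's FeatureSelectionTrainer.

```python
import torch
import torch.nn as nn

# Stand-in for a SToG selector: a learnable gate vector (assumption).
gates = nn.Parameter(torch.zeros(5))
model = nn.Linear(5, 1)

# One optimizer over both the model and the selector parameters.
opt = torch.optim.Adam(list(model.parameters()) + [gates], lr=1e-2)

x = torch.randn(32, 5)
y = torch.randn(32, 1)

for _ in range(10):
    opt.zero_grad()
    x_gated = x * torch.sigmoid(gates)                    # selector forward
    task_loss = nn.functional.mse_loss(model(x_gated), y)
    reg_loss = 0.1 * torch.sigmoid(gates).sum()           # sparsity penalty
    loss = task_loss + reg_loss                           # joint objective
    loss.backward()                                       # autograd covers both terms
    opt.step()
```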