Feature Selection Methods
Feature selector implementations.
Overview
The SToG.selectors module implements five feature selection methods, each with different
properties and use cases.
Stochastic Gates (STGLayer)
- class SToG.selectors.STGLayer(input_dim: int, sigma: float = 0.5, device: str = 'cpu')[source]
Bases: BaseFeatureSelector
Stochastic Gates (STG): the original formulation from Yamada et al. (2020). Uses a Gaussian-based continuous relaxation of Bernoulli variables.
Reference: "Feature Selection using Stochastic Gates" (Yamada et al., ICML 2020)
Method: Gaussian-based continuous relaxation (Yamada et al., 2020)
- When to use:
Balanced accuracy and sparsity
Need smooth gradient flow
Stable training on most datasets
- Parameters:
sigma - Standard deviation of the Gaussian noise (default: 0.5)
Larger sigma: more exploration, potentially less sparse
Smaller sigma: more deterministic behavior, faster convergence
Example:
from SToG import STGLayer
selector = STGLayer(input_dim=100, sigma=0.5)
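To make the gate mechanics concrete, here is a minimal, self-contained PyTorch sketch of the Gaussian-based relaxation from Yamada et al. (2020). It illustrates the technique only; the class name, the +0.5 hard-sigmoid shift, and the regularization() method are assumptions for illustration, not the SToG API.
import torch
import torch.nn as nn
from torch.distributions import Normal

class STGGateSketch(nn.Module):
    # Illustrative stochastic gate; not the SToG implementation.
    def __init__(self, input_dim: int, sigma: float = 0.5):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(input_dim))  # gate location parameters
        self.sigma = sigma

    def forward(self, x):
        # Relaxed Bernoulli gate: z = clamp(mu + 0.5 + eps, 0, 1), eps ~ N(0, sigma^2);
        # noise is injected only during training.
        eps = self.sigma * torch.randn_like(self.mu) if self.training else 0.0
        z = torch.clamp(self.mu + 0.5 + eps, 0.0, 1.0)
        return x * z

    def regularization(self):
        # Expected number of open gates: sum_i Phi((mu_i + 0.5) / sigma),
        # where Phi is the standard Gaussian CDF; penalizing this drives sparsity.
        return Normal(0.0, 1.0).cdf((self.mu + 0.5) / self.sigma).sum()
In a training loop, the penalty is added to the task loss, e.g. loss = task_loss + lam * selector.regularization() for some sparsity weight lam.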
Straight-Through Estimator (STELayer)
- class SToG.selectors.STELayer(input_dim: int, device: str = 'cpu')[source]
Bases: BaseFeatureSelector
Straight-Through Estimator for feature selection. Uses binary gates with gradient flow through a sigmoid.
Reference: "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation" (Bengio et al., 2013)
Method: Binary gates with gradient approximation (Bengio et al., 2013)
- When to use:
Need explicit binary decisions (on/off)
Prefer fast convergence
Working with small feature sets
- Advantages:
Produces true binary gates at inference
Fast training convergence
Clear feature selection (no fuzzy boundaries)
- Disadvantages:
Gradient approximation may be biased
Can get stuck in local optima
May over-select features
Example:
from SToG import STELayer
selector = STELayer(input_dim=100)
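The straight-through trick itself is compact enough to show directly. The sketch below uses the standard detach-based formulation: the forward pass applies a hard 0/1 threshold while gradients flow through the sigmoid. Class and attribute names are illustrative, not the SToG API.
import torch
import torch.nn as nn

class STEGateSketch(nn.Module):
    # Illustrative straight-through gate; not the SToG implementation.
    def __init__(self, input_dim: int):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(input_dim))

    def forward(self, x):
        soft = torch.sigmoid(self.logits)   # smooth surrogate for the backward pass
        hard = (soft > 0.5).float()         # true binary decision in the forward pass
        # Value equals `hard`, but the gradient is taken with respect to `soft`.
        gate = hard + soft - soft.detach()
        return x * gate
This is why STE yields true binary gates at inference while training with approximate (biased) gradients.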
Gumbel-Softmax (GumbelLayer)
- class SToG.selectors.GumbelLayer(input_dim: int, temperature: float = 1.0, device: str = 'cpu')[source]
Bases: BaseFeatureSelector
Gumbel-Softmax based feature selector. Uses a categorical distribution over {off, on} for each feature.
Reference: "Categorical Reparameterization with Gumbel-Softmax" (Jang et al., ICLR 2017)
This implementation properly handles the batch dimension and sampling.
Method: Categorical distribution relaxation (Jang et al., 2017)
- When to use:
Need principled probabilistic framework
Working with discrete latent variables
Can afford temperature annealing schedule
- Parameters:
temperature - Initial temperature \(\tau\) (default: 1.0)
Temperature annealing: \(\tau \to 0\) during training
Smaller temperature: more discrete behavior
- Advantages:
Theoretically grounded in Gumbel distribution
Flexible temperature schedule
Good for categorical problems
Example:
from SToG import GumbelLayer
selector = GumbelLayer(input_dim=100, temperature=1.0)
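PyTorch ships the Gumbel-Softmax estimator as torch.nn.functional.gumbel_softmax, so a per-feature {off, on} gate can be sketched as follows. The class name and the two-logit parameterization are assumptions for illustration, not the SToG API.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelGateSketch(nn.Module):
    # Illustrative Gumbel-Softmax gate; not the SToG implementation.
    def __init__(self, input_dim: int, temperature: float = 1.0):
        super().__init__()
        # One {off, on} logit pair per feature.
        self.logits = nn.Parameter(torch.zeros(input_dim, 2))
        self.temperature = temperature

    def forward(self, x):
        # Draw a relaxed one-hot sample over {off, on} for every feature.
        sample = F.gumbel_softmax(self.logits, tau=self.temperature, dim=-1)
        gate = sample[:, 1]   # probability mass assigned to "on"
        return x * gate
Annealing self.temperature toward 0 over training makes the samples increasingly discrete; passing hard=True to gumbel_softmax yields exact one-hot samples with straight-through gradients.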
Correlated STG
Bases: BaseFeatureSelector
STG with explicit handling of correlated features. Based on "Adaptive Group Sparse Regularization for Deep Neural Networks"; uses group structure to handle feature correlation.
Reference: "Adaptive Group Sparse Regularization for Deep Neural Networks"
Key operations: applies correlated stochastic gates to the input features, computes a regularization term with a correlation penalty, and reports selection probabilities for each feature.
L1 Regularization (L1Layer)
- class SToG.selectors.L1Layer(input_dim: int, device: str = 'cpu')[source]
Bases: BaseFeatureSelector
L1 regularization on the input layer weights. Baseline comparison method for feature selection.
Method: Classical L1 penalty on feature weights
- When to use:
Baseline comparison
Want interpretable feature weights
Need simple, proven method
- How it works:
Learns feature weights \(w \in \mathbb{R}^d\)
Gates input: \(\tilde{x} = w \odot x\)
Encourages small weights via L1 penalty
- Advantages:
Simple and interpretable
Fast convergence
Well-studied statistical properties
- Disadvantages:
Soft selection (weights are continuous)
May not achieve exact sparsity
Features selected by magnitude, not binary gates
Example:
from SToG import L1Layer
selector = L1Layer(input_dim=100)
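The "How it works" recipe above translates directly into code. The sketch below is a minimal version of that recipe; the class name and the regularization() method are illustrative, not the SToG API.
import torch
import torch.nn as nn

class L1GateSketch(nn.Module):
    # Illustrative L1-penalized feature weighting; not the SToG implementation.
    def __init__(self, input_dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.ones(input_dim))  # feature weights w in R^d

    def forward(self, x):
        return x * self.w             # x_tilde = w * x (elementwise gating)

    def regularization(self):
        return self.w.abs().sum()     # L1 penalty ||w||_1
Adding lam * selector.regularization() to the task loss shrinks unimportant weights toward zero, though, as noted above, it may not reach exact sparsity.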
Method Comparison
| Method | Convergence | Sparsity | Interpretability | Stability | Use Case |
|---|---|---|---|---|---|
| STG | Medium | Good | Good | High | General purpose |
| STE | Fast | Good | Excellent | Medium | Binary selection |
| Gumbel | Medium | Good | Good | Medium | Categorical |
| CorrelatedSTG | Slow | Excellent | Good | High | Correlated features |
| L1 | Fast | Fair | Good | High | Baseline |