Base Feature Selector Class
Base classes for feature selection methods.
- class SToG.base.BaseFeatureSelector(input_dim: int, device: str = 'cpu')
Base class for feature selection methods.
- __init__(input_dim: int, device: str = 'cpu')
Initialize the selector for inputs with input_dim features on the given device.
- abstractmethod forward(x: Tensor) → Tensor
Apply feature selection gates to input.
- Parameters:
x – Input tensor of shape [batch_size, input_dim]
- Returns:
Gated input tensor of shape [batch_size, input_dim]
- abstractmethod regularization_loss() → Tensor
Compute regularization loss for sparsity.
- Returns:
Scalar tensor with regularization loss
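The shape contract above (forward maps [batch_size, input_dim] to [batch_size, input_dim]; regularization_loss returns a scalar tensor) can be checked mechanically. The sketch below is a minimal illustration under stated assumptions: check_selector_contract is a hypothetical helper, not part of SToG, and SigmoidSelector and SummingSelector are throwaway stand-ins used only so the snippet runs without SToG installed.

```python
import torch
import torch.nn as nn


def check_selector_contract(selector: nn.Module, input_dim: int, batch_size: int = 8) -> None:
    """Hypothetical helper: assert the documented shapes of forward() and regularization_loss()."""
    x = torch.randn(batch_size, input_dim)
    out = selector(x)
    # forward: input [batch_size, input_dim] -> gated output of the same shape
    assert out.shape == (batch_size, input_dim), f"bad forward shape {tuple(out.shape)}"
    reg = selector.regularization_loss()
    # regularization_loss: a scalar (0-dimensional) tensor
    assert isinstance(reg, torch.Tensor) and reg.dim() == 0, "regularization_loss must return a scalar tensor"


class SigmoidSelector(nn.Module):
    """Stand-in selector (not from SToG): one logit per feature, sigmoid gates."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.logits)  # gates broadcast over the batch dimension

    def regularization_loss(self) -> torch.Tensor:
        return torch.sigmoid(self.logits).sum()


class SummingSelector(SigmoidSelector):
    """Deliberately broken stand-in: collapses the feature dimension, violating the contract."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x).sum(dim=1)


check_selector_contract(SigmoidSelector(5), input_dim=5)  # passes silently
try:
    check_selector_contract(SummingSelector(5), input_dim=5)
except AssertionError as err:
    print("contract violated:", err)
```

A check like this is cheap to run in a unit test for each new selector, since the base class only declares the methods and cannot enforce their output shapes.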
Overview
The SToG.base.BaseFeatureSelector is an abstract base class that defines the interface
for all feature selection methods in SToG. All concrete selector implementations must inherit
from this class and implement the required abstract methods.
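The abstract-method requirement can be illustrated with a stand-in. This sketch assumes the base class combines torch.nn.Module with Python's abc machinery (the inherited nn.Module docstring on __init__ suggests an nn.Module subclass, but the abc part is an assumption); SketchBaseSelector, IncompleteSelector, CompleteSelector and can_instantiate are hypothetical names. A subclass that omits regularization_loss cannot be instantiated.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class SketchBaseSelector(nn.Module, ABC):
    """Hypothetical stand-in for SToG.base.BaseFeatureSelector (assumed nn.Module + ABC)."""

    def __init__(self, input_dim: int, device: str = "cpu"):
        super().__init__()
        self.input_dim = input_dim
        self.device = device

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply feature selection gates to input."""

    @abstractmethod
    def regularization_loss(self) -> torch.Tensor:
        """Compute regularization loss for sparsity."""


class IncompleteSelector(SketchBaseSelector):
    # Implements forward but not regularization_loss, so it stays abstract.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class CompleteSelector(IncompleteSelector):
    # Supplies the missing method, so instances can be created.
    def regularization_loss(self) -> torch.Tensor:
        return torch.zeros(())


def can_instantiate(cls) -> bool:
    """Return True if cls(4) succeeds, False if unimplemented abstract methods block it."""
    try:
        cls(4)
        return True
    except TypeError:
        return False


print(can_instantiate(IncompleteSelector), can_instantiate(CompleteSelector))
```

The error is raised at construction time, so a missing method is caught as soon as the selector is created rather than partway through training.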
Key Methods
- forward(x) - Apply feature gating
Applies learned gate parameters to input features, returning gated input.
- regularization_loss() - Compute sparsity regularization
Returns a scalar loss that encourages sparse feature selection.
- get_selection_probs() - Get selection probabilities
Returns per-feature selection probabilities, used to determine which features are important.
- get_selected_features(threshold) - Get binary selection mask
Returns a binary mask indicating selected vs. discarded features.
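This page does not spell out how get_selected_features(threshold) relates to get_selection_probs(). One plausible reading, stated here as an assumption, is that a feature is selected when its selection probability exceeds threshold. The sketch below implements only that assumed rule on a plain probability tensor; mask_from_probs and its 0.5 default are hypothetical, and SToG may use a different comparison or default.

```python
import torch


def mask_from_probs(probs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical rule: 1 where the selection probability is strictly above threshold, else 0."""
    return (probs > threshold).to(torch.int64)


probs = torch.tensor([0.9, 0.2, 0.55, 0.5])
mask = mask_from_probs(probs)
print(mask.tolist())  # [1, 0, 1, 0]
# Indices of the selected features.
print(mask.nonzero(as_tuple=True)[0].tolist())  # [0, 2]
```

Whether a probability exactly equal to threshold counts as selected is a boundary choice; this sketch uses a strict comparison.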
Design Pattern
All selectors follow this pattern:
- Initialization - Set up learnable parameters
- Forward pass - Apply gates to input during training/inference
- Regularization - Compute sparsity-inducing loss
- Interpretation - Extract feature importance from learned parameters
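The four steps can be sketched end to end. This is a minimal sketch under several assumptions: the selector is a stand-in using the same sigmoid gating as the CustomSelector example on this page, the downstream model is a single linear layer, the training objective is the task loss plus lam times regularization_loss() (lam is a hypothetical weight), and the data are synthetic with only features 0 and 1 informative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class SigmoidGateSelector(nn.Module):
    """Stand-in selector (not from SToG): one learnable logit per feature, sigmoid gates."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.logits)

    def regularization_loss(self) -> torch.Tensor:
        return torch.sigmoid(self.logits).sum()

    def get_selection_probs(self) -> torch.Tensor:
        return torch.sigmoid(self.logits).detach()


# Synthetic data: the target depends only on features 0 and 1 out of 6.
n, d = 512, 6
x = torch.randn(n, d)
y = 3.0 * x[:, 0] - 2.0 * x[:, 1]

# 1. Initialization: the selector plus a downstream model.
selector = SigmoidGateSelector(d)
model = nn.Linear(d, 1)
optimizer = torch.optim.Adam(list(selector.parameters()) + list(model.parameters()), lr=0.05)
lam = 0.05  # hypothetical sparsity weight

for step in range(500):
    optimizer.zero_grad()
    # 2. Forward pass: gate the inputs, then predict.
    pred = model(selector(x)).squeeze(-1)
    task_loss = nn.functional.mse_loss(pred, y)
    # 3. Regularization: add the weighted sparsity penalty to the task loss.
    loss = task_loss + lam * selector.regularization_loss()
    loss.backward()
    optimizer.step()

# 4. Interpretation: read per-feature selection probabilities.
probs = selector.get_selection_probs()
print([round(p, 3) for p in probs.tolist()])
```

Because the linear layer's weights are not penalized, the model can partly offset the penalty by shrinking a gate and enlarging the matching weight; this sketch does not guard against that.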
Example Implementation
import torch
import torch.nn as nn

from SToG.base import BaseFeatureSelector


class CustomSelector(BaseFeatureSelector):
    """Custom feature selector with one learnable sigmoid gate per feature."""

    def __init__(self, input_dim: int, device: str = 'cpu'):
        super().__init__(input_dim, device)
        # One unconstrained logit per feature; sigmoid maps it to a gate in (0, 1).
        self.weights = nn.Parameter(torch.randn(input_dim, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate features using learned weights."""
        gates = torch.sigmoid(self.weights)
        # x: [batch_size, input_dim]; gates: [1, input_dim], broadcast over the batch.
        return x * gates.unsqueeze(0)

    def regularization_loss(self) -> torch.Tensor:
        """Sparsity regularization: sum of gate values (lower means fewer open gates)."""
        return torch.sum(torch.sigmoid(self.weights))

    def get_selection_probs(self) -> torch.Tensor:
        """Selection probabilities, detached from the autograd graph."""
        return torch.sigmoid(self.weights).detach()