Base Feature Selector Class

Base classes for feature selection methods.

class SToG.base.BaseFeatureSelector(input_dim: int, device: str = 'cpu')[source]

Bases: Module, ABC

Base class for feature selection methods.

__init__(input_dim: int, device: str = 'cpu')[source]

Initialize the feature selector with the number of input features and the device on which its parameters are placed.

abstractmethod forward(x: Tensor) → Tensor[source]

Apply feature selection gates to input.

Parameters:

x – Input tensor of shape [batch_size, input_dim]

Returns:

Gated input tensor of shape [batch_size, input_dim]

abstractmethod regularization_loss() → Tensor[source]

Compute regularization loss for sparsity.

Returns:

Scalar tensor with regularization loss

abstractmethod get_selection_probs() → Tensor[source]

Get feature selection probabilities.

Returns:

Tensor of shape [input_dim] with selection probabilities

get_selected_features(threshold: float = 0.5) → ndarray[source]

Get binary mask of selected features.

Parameters:

threshold – Probability threshold for selection

Returns:

Boolean array of shape [input_dim]
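
The implementation is not shown here, but the signature implies the mask is obtained by thresholding the output of get_selection_probs(). A minimal sketch of that relationship, with made-up probability values:

import torch

# Hypothetical probabilities, standing in for selector.get_selection_probs()
probs = torch.tensor([0.9, 0.1, 0.7, 0.3])

# Thresholding yields the kind of boolean mask get_selected_features(0.5) returns
mask = (probs >= 0.5).cpu().numpy()
print(mask)  # [ True False  True False]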

Overview

The SToG.base.BaseFeatureSelector is an abstract base class that defines the interface for all feature selection methods in SToG. All concrete selector implementations must inherit from this class and implement the required abstract methods.

Key Methods

forward(x) - Apply feature gating

Applies the learned gate parameters to the input features and returns the gated input.

regularization_loss() - Compute sparsity regularization

Returns a scalar loss that encourages sparse feature selection.

get_selection_probs() - Get selection probabilities

Returns per-feature selection probabilities, which are used to determine which features are important.

get_selected_features(threshold) - Get binary selection mask

Returns a binary mask indicating selected vs. discarded features.

Design Pattern

All selectors follow this pattern (a training sketch follows the list):

  1. Initialization - Set up learnable parameters

  2. Forward pass - Apply gates to input during training/inference

  3. Regularization - Compute sparsity-inducing loss

  4. Interpretation - Extract feature importance from learned parameters
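
As a concrete illustration of these steps, the sketch below trains a selector jointly with a downstream model. It assumes the CustomSelector class from the Example Implementation section below; the linear head, the synthetic data, and the 0.01 regularization weight are placeholders, not part of the SToG API.

import torch
import torch.nn as nn

# 1. Initialization: selector (CustomSelector from the example below) plus a task model
selector = CustomSelector(input_dim=20)
model = nn.Linear(20, 1)
optimizer = torch.optim.Adam(
    list(selector.parameters()) + list(model.parameters()), lr=1e-2
)

x, y = torch.randn(64, 20), torch.randn(64, 1)  # synthetic placeholder data

for _ in range(200):
    optimizer.zero_grad()
    pred = model(selector(x))                             # 2. Forward pass: gated features feed the model
    loss = nn.functional.mse_loss(pred, y)
    loss = loss + 0.01 * selector.regularization_loss()   # 3. Sparsity regularization
    loss.backward()
    optimizer.step()

mask = selector.get_selected_features(threshold=0.5)      # 4. Interpretation: boolean mask
print(mask.sum(), "features selected")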

Example Implementation

from SToG.base import BaseFeatureSelector
import torch
import torch.nn as nn

class CustomSelector(BaseFeatureSelector):
    """Custom feature selector implementation."""

    def __init__(self, input_dim: int, device: str = 'cpu'):
        super().__init__(input_dim, device)
        self.weights = nn.Parameter(torch.randn(input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate features using learned weights."""
        gates = torch.sigmoid(self.weights)
        return x * gates.unsqueeze(0)

    def regularization_loss(self) -> torch.Tensor:
        """Sparsity regularization."""
        return torch.sum(torch.sigmoid(self.weights))

    def get_selection_probs(self) -> torch.Tensor:
        """Selection probabilities."""
        return torch.sigmoid(self.weights).detach()
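
A quick usage check of this custom selector, continuing from the imports above. Because get_selection_probs() is implemented, the get_selected_features() method inherited from BaseFeatureSelector works without any extra code:

selector = CustomSelector(input_dim=5)
x = torch.randn(8, 5)

gated = selector(x)                          # shape [8, 5], features scaled by their gates
probs = selector.get_selection_probs()       # shape [5], values in (0, 1)
mask = selector.get_selected_features(0.5)   # inherited from BaseFeatureSelector

print(gated.shape, probs, mask)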