SToG.base

Base classes for feature selection methods.

Classes

BaseFeatureSelector(input_dim[, device])

Base class for feature selection methods.

class SToG.base.BaseFeatureSelector(input_dim: int, device: str = 'cpu')

Bases: Module, ABC

Base class for feature selection methods.

__init__(input_dim: int, device: str = 'cpu')

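Initialize the selector together with the internal nn.Module state.

Parameters:

input_dim – Number of input features; matches the second dimension of the tensors passed to forward

device – Device identifier string (default 'cpu')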

abstractmethod forward(x: Tensor) → Tensor

Apply feature selection gates to input.

Parameters:

x – Input tensor of shape [batch_size, input_dim]

Returns:

Gated input tensor of shape [batch_size, input_dim]
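The shape contract above can be illustrated with plain tensor broadcasting. This is only a sketch: the gate values and sizes below are made up, and how a concrete selector produces its gates is not specified by this base class.

```python
import torch

# Hypothetical sizes and gate values, chosen only to illustrate the shapes.
batch_size, input_dim = 4, 3
x = torch.ones(batch_size, input_dim)       # input: [batch_size, input_dim]
gates = torch.tensor([1.0, 0.0, 0.5])       # one gate per feature: [input_dim]

# Broadcasting applies the same per-feature gate to every row of the batch.
gated = x * gates                           # output: [batch_size, input_dim]
```

A gate of 0.0 zeroes the corresponding feature column for the whole batch, while a gate of 1.0 passes it through unchanged.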

abstractmethod regularization_loss() → Tensor

Compute regularization loss for sparsity.

Returns:

Scalar tensor with regularization loss
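A scalar sparsity term like this is typically added to the task loss with a weighting factor. The sketch below assumes a penalty equal to the sum of sigmoid gate probabilities and a hypothetical weight `sparsity_weight`; neither is taken from SToG itself.

```python
import torch

# Stand-in for a selector's learnable gate parameters (assumption).
logits = torch.zeros(3, requires_grad=True)

task_loss = torch.tensor(0.8)               # e.g. a classification loss value
reg_loss = torch.sigmoid(logits).sum()      # scalar penalty, assumed form
sparsity_weight = 0.1                       # hypothetical hyperparameter

total_loss = task_loss + sparsity_weight * reg_loss
total_loss.backward()                       # gradients reach the gate parameters
```

Because the penalty is a scalar tensor built from the gate parameters, a single `backward()` call on the combined loss also updates the gates toward sparsity.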

abstractmethod get_selection_probs() → Tensor

Get feature selection probabilities.

Returns:

Tensor of shape [input_dim] with selection probabilities
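The three abstract methods above could be implemented along the following lines. To stay self-contained, the sketch defines a local stand-in `SelectorInterface` that mirrors the documented signatures instead of importing `SToG.base.BaseFeatureSelector`. The `SigmoidGateSelector` class, its `logits` parameter, and the stored `input_dim` and `device` attributes are all hypothetical.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class SelectorInterface(nn.Module, ABC):
    """Local stand-in mirroring the documented abstract methods."""

    def __init__(self, input_dim: int, device: str = "cpu"):
        super().__init__()
        self.input_dim = input_dim   # assumption: stored as attributes
        self.device = device

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def regularization_loss(self) -> torch.Tensor: ...

    @abstractmethod
    def get_selection_probs(self) -> torch.Tensor: ...


class SigmoidGateSelector(SelectorInterface):
    """Hypothetical selector with one learnable logit per feature."""

    def __init__(self, input_dim: int, device: str = "cpu"):
        super().__init__(input_dim, device)
        self.logits = nn.Parameter(torch.zeros(input_dim, device=device))

    def get_selection_probs(self) -> torch.Tensor:
        return torch.sigmoid(self.logits)         # shape [input_dim]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.get_selection_probs()     # shape [batch_size, input_dim]

    def regularization_loss(self) -> torch.Tensor:
        return self.get_selection_probs().sum()   # scalar tensor


selector = SigmoidGateSelector(input_dim=5)
out = selector(torch.randn(8, 5))
```

Since the stand-in base derives from `ABC`, instantiating it directly raises `TypeError`; only subclasses that override all three abstract methods can be constructed.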

get_selected_features(threshold: float = 0.5) → ndarray

Get binary mask of selected features.

Parameters:

threshold – Probability threshold for selection

Returns:

Boolean array of shape [input_dim]
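A thresholding step of this kind might look like the sketch below. It assumes a strict greater-than comparison against the probabilities; the library may use a different comparison, and the probability values here are made up.

```python
import numpy as np
import torch

# Stand-in for the tensor returned by get_selection_probs().
probs = torch.tensor([0.9, 0.2, 0.4, 0.7])
threshold = 0.5

# Assumption: strict ">" comparison; the actual implementation may use ">=".
mask = probs.detach().cpu().numpy() > threshold   # boolean array, shape [input_dim]
selected_indices = np.flatnonzero(mask)           # positions of the kept features
```

The boolean mask can be used to index the columns of a NumPy feature matrix directly, for example `X[:, mask]`.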