SToG.models

Model definitions for classification tasks.

Functions

create_classification_model(input_dim, ...)

Create a simple feedforward neural network for classification.

SToG.models.create_classification_model(input_dim: int, num_classes: int, hidden_dim: int = None) → Module

Create a simple feedforward neural network for classification.

Parameters:
  • input_dim – Number of input features

  • num_classes – Number of output classes

  • hidden_dim – Hidden layer dimension (auto-calculated if None)

Returns:
  PyTorch sequential model
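The following is a minimal, self-contained sketch of how a model like this might be built and used. It does not import SToG. Instead it defines a hypothetical stand-in, `make_feedforward_classifier`, with the same parameters. Several details are assumptions and are not taken from SToG: the single hidden layer with a ReLU activation, the placeholder rule for choosing `hidden_dim` when it is None, and the choice to output raw logits with no softmax. The usage part runs a forward pass and computes a cross-entropy loss on random data.

```python
from typing import Optional

import torch
from torch import nn


def make_feedforward_classifier(
    input_dim: int, num_classes: int, hidden_dim: Optional[int] = None
) -> nn.Module:
    """Hypothetical stand-in mirroring the documented signature (not SToG's code)."""
    if hidden_dim is None:
        # ASSUMPTION: placeholder width rule; SToG's actual auto-calculation is not documented.
        hidden_dim = max(num_classes, input_dim // 2)
    # ASSUMPTION: one hidden layer with ReLU; the output layer emits raw logits.
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, num_classes),
    )


torch.manual_seed(0)
model = make_feedforward_classifier(input_dim=20, num_classes=3)

x = torch.randn(8, 20)                  # batch of 8 samples, 20 features each
logits = model(x)                       # one score per class for each sample
targets = torch.randint(0, 3, (8,))     # random integer class labels in [0, 3)
loss = nn.functional.cross_entropy(logits, targets)  # expects raw logits
loss.backward()                         # populates gradients for all parameters
```

Passing `hidden_dim` explicitly, as in `make_feedforward_classifier(4, 2, hidden_dim=16)`, bypasses the assumed default rule. Because `cross_entropy` applies log-softmax internally, a model that emits raw logits pairs with it directly. If a model ends in a softmax layer, it would need a different loss setup.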