Skip to content

Neural Ensemble Search (NES)

Algorithms for searching diverse and high-performing architectures.

bensemble.search.nes.RandomSearcher

RandomSearcher(
    space: SearchSpace,
    pool_size: int = 50,
    ensemble_size: int = 5,
    train_fn: Optional[Callable[[Module], None]] = None,
    device: Optional[device] = None,
    criterion: Optional[
        Callable[[list[Module], DataLoader, device], float]
    ] = None,
)

NES with Random Search (NES-RS, Algorithm 2 in Zaidi et al. NeurIPS 2021).

Builds a pool of `pool_size` independently trained models with randomly sampled architectures, then applies greedy forward ensemble selection (ForwardSelect) to pick the final ensemble of `ensemble_size` members.

Source code in bensemble/search/nes.py
def __init__(
    self,
    space: SearchSpace,
    pool_size: int = 50,
    ensemble_size: int = 5,
    train_fn: Optional[Callable[[nn.Module], None]] = None,
    device: Optional[torch.device] = None,
    criterion: Optional[
        Callable[[list[nn.Module], DataLoader, torch.device], float]
    ] = None,
) -> None:
    """Configure an NES-RS searcher.

    Args:
        space: Search space providing ``sample()`` and ``build()`` for the
            pool architectures.
        pool_size: Number of models built for the pool; must be >= 1.
        ensemble_size: Number of members selected from the pool; must be
            between 1 and ``pool_size``.
        train_fn: Routine that ``search`` calls on each freshly built model.
        device: Device handed to ensemble selection. Defaults to CUDA when
            available, otherwise CPU.
        criterion: Ensemble scoring function handed to ``forward_select``.
            Defaults to ``classification_nll_criterion``.

    Raises:
        ValueError: If ``pool_size`` or ``ensemble_size`` is < 1, or if
            ``ensemble_size`` exceeds ``pool_size`` (same checks as
            ``NESBayesianSampler``).
    """
    # Reject configurations that cannot yield a valid ensemble up front,
    # instead of after the expensive pool-building phase.
    if pool_size < 1:
        raise ValueError("pool_size must be >= 1.")
    if ensemble_size < 1:
        raise ValueError("ensemble_size must be >= 1.")
    if ensemble_size > pool_size:
        raise ValueError("ensemble_size must be <= pool_size.")

    self.space = space
    self.pool_size = pool_size
    self.ensemble_size = ensemble_size
    self.train_fn = train_fn
    self.device = device or torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    self.criterion = (
        criterion if criterion is not None else classification_nll_criterion
    )

search

search(
    val_loader: DataLoader,
    val_loader_shift: Optional[DataLoader] = None,
) -> Ensemble

Run NES-RS and return the selected ensemble.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `val_loader` | `DataLoader` | Validation loader used for ensemble selection. | *required* |
| `val_loader_shift` | `Optional[DataLoader]` | If provided, used instead of `val_loader` for the final ForwardSelect call (dataset-shift adaptation). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Ensemble` | Ensemble of `ensemble_size` members. |

Source code in bensemble/search/nes.py
def search(
    self,
    val_loader: DataLoader,
    val_loader_shift: Optional[DataLoader] = None,
) -> Ensemble:
    """Run NES-RS and return the selected ensemble.

    Samples ``pool_size`` architectures independently from the search
    space and builds one model per architecture. Each model is trained with
    ``train_fn`` when one is configured; with ``train_fn=None`` the built
    models enter the pool untrained. ``forward_select`` then greedily picks
    ``ensemble_size`` members from the pool.

    Args:
        val_loader: Validation loader used for ensemble selection.
        val_loader_shift: If provided, used instead of `val_loader` for
            the final ForwardSelect call (dataset-shift adaptation).

    Returns:
        Ensemble of `ensemble_size` members.
    """
    pool: list[nn.Module] = []
    for _ in range(self.pool_size):
        model = self.space.build(self.space.sample())
        # ``train_fn`` is declared Optional with a default of None; calling
        # it unguarded would raise "'NoneType' object is not callable".
        if self.train_fn is not None:
            self.train_fn(model)
        pool.append(model)

    # Dataset-shift adaptation: a shifted loader, when given, replaces the
    # clean one for the final selection step.
    selection_loader = (
        val_loader if val_loader_shift is None else val_loader_shift
    )
    selected = forward_select(
        pool, selection_loader, self.ensemble_size, self.device, self.criterion
    )
    return Ensemble.from_models(selected)

bensemble.search.nes.EvolutionarySearcher

EvolutionarySearcher(
    space: SearchSpace,
    pool_size: int = 50,
    ensemble_size: int = 5,
    population_size: int = 10,
    num_parent_candidates: int = 3,
    train_fn: Optional[Callable[[Module], None]] = None,
    device: Optional[device] = None,
    criterion: Optional[
        Callable[[list[Module], DataLoader, device], float]
    ] = None,
)

NES with Regularized Evolution (NES-RE, Algorithm 1 in Zaidi et al. NeurIPS 2021).

Evolves a population of architectures using ensemble-aware parent selection (ForwardSelect on the population) and single-step mutation. The full history of trained models forms the pool from which the final ensemble is selected.

Source code in bensemble/search/nes.py
def __init__(
    self,
    space: SearchSpace,
    pool_size: int = 50,
    ensemble_size: int = 5,
    population_size: int = 10,
    num_parent_candidates: int = 3,
    train_fn: Optional[Callable[[nn.Module], None]] = None,
    device: Optional[torch.device] = None,
    criterion: Optional[
        Callable[[list[nn.Module], DataLoader, torch.device], float]
    ] = None,
) -> None:
    """Configure an NES-RE searcher.

    Args:
        space: Search space providing ``sample()``, ``build()`` and
            ``mutate()`` for the evolved architectures.
        pool_size: Total number of models built (seed population plus
            children); must be >= 1.
        ensemble_size: Number of members selected from the pool; must be
            between 1 and ``pool_size``.
        population_size: Number of randomly seeded models that form the
            evolving population. Every seeded model also enters the pool,
            so this must be between 1 and ``pool_size``.
        num_parent_candidates: Number of parent candidates ForwardSelect
            picks from the population per evolution step; must be between
            1 and ``population_size``.
        train_fn: Routine that ``search`` calls on each freshly built model.
        device: Device handed to ensemble selection. Defaults to CUDA when
            available, otherwise CPU.
        criterion: Ensemble scoring function handed to ``forward_select``.
            Defaults to ``classification_nll_criterion``.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if pool_size < 1:
        raise ValueError("pool_size must be >= 1.")
    if ensemble_size < 1:
        raise ValueError("ensemble_size must be >= 1.")
    if ensemble_size > pool_size:
        raise ValueError("ensemble_size must be <= pool_size.")
    if population_size < 1:
        raise ValueError("population_size must be >= 1.")
    # Seeding alone adds ``population_size`` models to the pool; a larger
    # population would silently exceed the ``pool_size`` budget.
    if population_size > pool_size:
        raise ValueError("population_size must be <= pool_size.")
    if num_parent_candidates < 1:
        raise ValueError("num_parent_candidates must be >= 1.")
    # Parent candidates are selected from the population, so there must be
    # at least that many members to choose from.
    if num_parent_candidates > population_size:
        raise ValueError("num_parent_candidates must be <= population_size.")

    self.space = space
    self.pool_size = pool_size
    self.ensemble_size = ensemble_size
    self.population_size = population_size
    self.num_parent_candidates = num_parent_candidates
    self.train_fn = train_fn
    self.device = device or torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    self.criterion = (
        criterion if criterion is not None else classification_nll_criterion
    )

search

search(
    val_loader: DataLoader,
    val_loader_shift: Optional[DataLoader] = None,
) -> Ensemble

Run NES-RE and return the selected ensemble.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `val_loader` | `DataLoader` | Validation loader used for parent selection and final ensemble selection. | *required* |
| `val_loader_shift` | `Optional[DataLoader]` | If provided, used instead of `val_loader` for the final ForwardSelect call (dataset-shift adaptation). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Ensemble` | Ensemble of `ensemble_size` members. |

Source code in bensemble/search/nes.py
def search(
    self,
    val_loader: DataLoader,
    val_loader_shift: Optional[DataLoader] = None,
) -> Ensemble:
    """Run NES-RE and return the selected ensemble.

    The search has three phases:

    1. Seed a population of ``population_size`` randomly sampled
       architectures; every built model also enters the pool.
    2. While the pool holds fewer than ``pool_size`` models: pick
       ``num_parent_candidates`` population members with ``forward_select``,
       draw one uniformly at random as the parent, mutate its config into
       a child, build the child, and append it to both population and pool.
       Then drop the oldest population member (regularized evolution).
    3. Run ``forward_select`` over the whole pool to pick the final
       ``ensemble_size`` members.

    Each built model is trained with ``train_fn`` when one is configured;
    with ``train_fn=None`` models are used untrained.

    Args:
        val_loader: Validation loader used for parent selection and final
            ensemble selection.
        val_loader_shift: If provided, used instead of `val_loader` for
            the final ForwardSelect call (dataset-shift adaptation). Parent
            selection always uses `val_loader`.

    Returns:
        Ensemble of `ensemble_size` members.
    """
    import copy  # only needed for the defensive config copy below

    # Maps id(model) -> config to support mutation of selected parents.
    # Every model also stays referenced from ``pool``, so an id cannot be
    # reused by a different object while this mapping is alive.
    config_map: dict[int, dict] = {}

    def _build_and_train(config: dict) -> nn.Module:
        """Build a model from ``config``, train it if possible, record its config."""
        model = self.space.build(config)
        # ``train_fn`` is declared Optional; skip training rather than crash
        # with "'NoneType' object is not callable".
        if self.train_fn is not None:
            self.train_fn(model)
        config_map[id(model)] = config
        return model

    # --- Initialisation: seed population and pool ---
    population: deque[nn.Module] = deque()
    pool: list[nn.Module] = []

    for _ in range(self.population_size):
        model = _build_and_train(self.space.sample())
        population.append(model)
        pool.append(model)

    # --- Evolution loop ---
    while len(pool) < self.pool_size:
        # Ensemble-aware parent selection: ForwardSelect picks m candidates
        # from the current population, then one is drawn uniformly.
        parent_candidates = forward_select(
            list(population),
            val_loader,
            self.num_parent_candidates,
            self.device,
            self.criterion,
        )
        parent = random.choice(parent_candidates)

        # Hand ``mutate`` a copy. If an implementation edits its argument in
        # place, the parent's recorded config would otherwise change too,
        # and a later selection of that parent would mutate the wrong config.
        child_config = self.space.mutate(copy.deepcopy(config_map[id(parent)]))
        child = _build_and_train(child_config)

        population.append(child)
        pool.append(child)

        # Aging step of regularized evolution: drop the oldest member, which
        # keeps the population at its seeded size.
        population.popleft()

    # --- Final ensemble selection ---
    selection_loader = (
        val_loader if val_loader_shift is None else val_loader_shift
    )
    selected = forward_select(
        pool, selection_loader, self.ensemble_size, self.device, self.criterion
    )
    return Ensemble.from_models(selected)

Bayesian Sampling (SVGD)

bensemble.search.bayesian.NESBayesianSampler

NESBayesianSampler(
    space: SearchSpace,
    train_fn: Callable[[Module], None],
    pool_size: int = 50,
    ensemble_size: int = 5,
    temperature: float = 1.0,
    diversity_weight: float = 0.5,
    svgd_steps: int = 20,
    svgd_lr: float = 0.1,
    device: Optional[device] = None,
    criterion: Optional[
        Callable[[list[Module], DataLoader, device], float]
    ] = None,
)

Neural Ensemble Search via Bayesian Sampling (NESBS, Shu et al., UAI 2022). This implementation follows the paper's practical recipe:

1. build a candidate model pool from a search space;
2. estimate a posterior over candidates from validation losses;
3. select a diverse final ensemble either by
    - weighted Monte Carlo sampling (`sample_mc`), or
    - an SVGD-inspired iterative refinement with diversity regularization (`sample_svgd`).

Source code in bensemble/search/bayesian.py
def __init__(
    self,
    space: SearchSpace,
    train_fn: Callable[[nn.Module], None],
    pool_size: int = 50,
    ensemble_size: int = 5,
    temperature: float = 1.0,
    diversity_weight: float = 0.5,
    svgd_steps: int = 20,
    svgd_lr: float = 0.1,
    device: Optional[torch.device] = None,
    criterion: Optional[
        Callable[[list[nn.Module], DataLoader, torch.device], float]
    ] = None,
) -> None:
    """Validate and store the NESBS configuration.

    Args:
        space: Search space the candidate pool is drawn from.
        train_fn: Training routine for candidate models (stored for pool
            building).
        pool_size: Number of candidates in the pool; must be >= 1.
        ensemble_size: Number of members in the final ensemble; must be
            between 1 and ``pool_size``.
        temperature: Must be > 0. Presumably scales the posterior computed
            in ``_posterior_probs`` — TODO confirm there.
        diversity_weight: Weight of the pairwise-diversity bonus used by
            ``sample_svgd``.
        svgd_steps: Number of refinement sweeps in ``sample_svgd``; must be
            >= 1.
        svgd_lr: Step scale multiplying the diversity bonus in
            ``sample_svgd``.
        device: Evaluation device. A falsy value falls back to CUDA when
            available, otherwise CPU.
        criterion: Ensemble scoring function. ``None`` selects
            ``classification_nll_criterion``.

    Raises:
        ValueError: If any of the documented constraints is violated
            (checked in the order listed above).
    """

    def _reject_if(violated: bool, reason: str) -> None:
        # Raise ValueError(reason) when a constraint check fails.
        if violated:
            raise ValueError(reason)

    _reject_if(pool_size < 1, "pool_size must be >= 1.")
    _reject_if(ensemble_size < 1, "ensemble_size must be >= 1.")
    _reject_if(ensemble_size > pool_size, "ensemble_size must be <= pool_size.")
    _reject_if(temperature <= 0, "temperature must be > 0.")
    _reject_if(svgd_steps < 1, "svgd_steps must be >= 1.")

    # Search configuration.
    self.space = space
    self.train_fn = train_fn
    self.pool_size = pool_size
    self.ensemble_size = ensemble_size

    # Posterior and SVGD-refinement hyperparameters.
    self.temperature = temperature
    self.diversity_weight = diversity_weight
    self.svgd_steps = svgd_steps
    self.svgd_lr = svgd_lr

    if not device:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.device = device

    if criterion is None:
        criterion = classification_nll_criterion
    self.criterion = criterion

sample_mc

sample_mc(val_loader: DataLoader) -> Ensemble

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `val_loader` | `DataLoader` | Used to evaluate the posterior. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Ensemble` | `Ensemble` | The final ensemble wrapped in bensemble's core abstraction. |

Source code in bensemble/search/bayesian.py
def sample_mc(self, val_loader: DataLoader) -> Ensemble:
    """Pick the ensemble by Monte Carlo sampling from the candidate posterior.

    Obtains the candidate pool via ``_build_pool``, converts it to posterior
    probabilities with ``_posterior_probs``, and draws ``ensemble_size``
    distinct candidates with ``torch.multinomial`` (without replacement).
    Candidates with higher posterior mass are therefore more likely to be
    chosen.

    NOTE(review): ``torch.multinomial`` without replacement needs at least
    ``ensemble_size`` non-zero probabilities; a posterior that underflows
    for many candidates would make this call raise.

    Args:
        val_loader (DataLoader): Used to evaluate the posterior.

    Returns:
        Ensemble: The final ensemble wrapped in bensemble's core abstraction.
    """
    pool = self._build_pool(val_loader)
    weights = self._posterior_probs(pool)
    # Sampling without replacement keeps every ensemble member distinct.
    picked_indices = torch.multinomial(
        weights, num_samples=self.ensemble_size, replacement=False
    ).tolist()
    members = [pool[index].model for index in picked_indices]
    return Ensemble.from_models(members)

sample_svgd

sample_svgd(val_loader: DataLoader) -> Ensemble

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `val_loader` | `DataLoader` | Used to evaluate the architecture's loss/posterior. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Ensemble` | `Ensemble` | The final ensemble wrapped in bensemble's core abstraction. |

Source code in bensemble/search/bayesian.py
def sample_svgd(self, val_loader: DataLoader) -> Ensemble:
    """Pick the ensemble with an SVGD-inspired coordinate refinement.

    Particles are indices into the candidate pool. They start as
    ``ensemble_size`` distinct candidates sampled from the posterior. Each
    of up to ``svgd_steps`` sweeps revisits every particle slot ``i`` and
    moves it to the candidate, not held by another slot, that maximises::

        log(posterior[c] + 1e-8)
            + svgd_lr * diversity_weight * sum_{j != i} div(c, particle_j)

    Here ``div`` is ``_pairwise_diversity`` applied to the candidates'
    ``probs``. On ties, the lowest-index candidate wins. A sweep that moves
    no particle ends the refinement early.

    Args:
        val_loader (DataLoader): Used to evaluate the architecture's loss/posterior.

    Returns:
        Ensemble: The final ensemble wrapped in bensemble's core abstraction.
    """
    candidates = self._build_pool(val_loader)
    posterior = self._posterior_probs(candidates)
    n = len(candidates)
    # Epsilon guards against log(0) for candidates with zero posterior mass.
    logits = torch.log(posterior + 1e-8)
    # Convert once to Python floats instead of calling ``.item()`` per candidate.
    logit_values = logits.tolist()
    # Scale of the diversity bonus; constant for the whole refinement.
    repulsion_scale = self.svgd_lr * self.diversity_weight

    particles = torch.multinomial(
        posterior, num_samples=self.ensemble_size, replacement=False
    )
    for _ in range(self.svgd_steps):
        moved = False
        for i in range(self.ensemble_size):
            current = particles[i].item()
            # Indices held by any slot. Rebuilt per slot, since updating
            # earlier slots changes it, but hoisted out of the candidate loop.
            occupied = set(particles.tolist())
            # The other slots' picks, in slot order, for the repulsion sum.
            others = [
                particles[j].item() for j in range(self.ensemble_size) if j != i
            ]
            best_idx = current
            best_value = float("-inf")

            for candidate_idx in range(n):
                # Keep particles distinct: skip candidates held by other slots.
                if candidate_idx != current and candidate_idx in occupied:
                    continue
                repulsion = 0.0
                for other_idx in others:
                    repulsion += self._pairwise_diversity(
                        candidates[candidate_idx].probs,
                        candidates[other_idx].probs,
                    )
                value = logit_values[candidate_idx] + repulsion_scale * repulsion
                if value > best_value:
                    best_value = value
                    best_idx = candidate_idx

            if best_idx != current:
                particles[i] = best_idx
                moved = True

        # A sweep without moves leaves the state unchanged, so any further
        # sweep would repeat it exactly (assuming ``_pairwise_diversity`` is
        # deterministic); stop early.
        if not moved:
            break

    models = [candidates[idx].model for idx in particles.tolist()]
    return Ensemble.from_models(models)