Skip to content

Base Pruner

BaseNetDistributionPruner

Base pruner for NetDistribution. It will prune all weights which have a high probability of being 0.

Source code in src/methods/bayes/base/net_distribution.py
class BaseNetDistributionPruner:
    """
    Base pruner for a ``BaseNetDistribution``. It prunes (zeroes out) all weights
    that have a high probability of being 0 under the posterior distribution.
    """

    def __init__(self, net_distribution: BaseNetDistribution) -> None:
        """Initialize the pruner with all-ones dropout masks (nothing pruned yet).

        Args:
            net_distribution (BaseNetDistribution): posterior distribution of the
                net, which decides how probable a zero value is for each weight.
        """
        self.net_distribution = net_distribution
        # One binary mask per weight tensor; the shape is taken from a sample of
        # the corresponding parameter distribution.
        self.dropout_mask_dict: dict[str, nn.Parameter] = {}
        for name_dist, dist in self.net_distribution.weight_distribution.items():
            self.dropout_mask_dict[name_dist] = nn.Parameter(
                torch.ones_like(dist.sample())
            )

    def prune(self, threshold: float | dict[str, float]) -> None:
        """
        Prune every weight whose prune estimate (log_z_test) is below the threshold.

        Args:
            threshold: a single global threshold, or a mapping from weight name
                to a per-weight threshold.
        """
        for weight_name in self.net_distribution.weight_distribution:
            weight_threshold = threshold
            if isinstance(weight_threshold, dict):
                weight_threshold = weight_threshold[weight_name]
            self.prune_weight(weight_name, weight_threshold)

    def prune_weight(self, weight_name: str, threshold: float) -> None:
        """
        Prune one weight tensor: entries whose prune estimate (log_z_test) is
        below the threshold are zeroed out in the base module.
        """
        self.set_weight_dropout_mask(weight_name, threshold)
        pt = get_attr(self.net_distribution.base_module, weight_name.split("."))
        pt = pt * self.dropout_mask_dict[weight_name]
        # Re-wrap as a Parameter so the base module keeps a trainable attribute.
        pt = nn.Parameter(pt)
        set_attr(self.net_distribution.base_module, weight_name.split("."), pt)

    def set_weight_dropout_mask(self, weight_name: str, threshold: float) -> None:
        """
        Recompute the weight's dropout mask: 1.0 where log_z_test >= threshold
        (kept), 0.0 elsewhere (pruned).
        """
        dist = self.net_distribution.weight_distribution[weight_name]
        self.dropout_mask_dict[weight_name].data = 1.0 * (
            dist.log_z_test() >= threshold
        )

    def prune_stats(self) -> int:
        """
        Get the number of pruned parameters (mask entries equal to 0).

        Returns:
            int: number of pruned parameters
        """
        prune_cnt = 0
        for dropout in self.dropout_mask_dict.values():
            # Masks hold 1.0 (kept) / 0.0 (pruned); count the zeros. Convert to a
            # plain Python int so the declared return type actually holds
            # (the previous code returned a 0-dim torch.Tensor).
            prune_cnt += int((1 - dropout).sum().item())
        return prune_cnt

    def total_params(self) -> int:
        """
        Get the total number of parameters covered by the dropout masks.

        Returns:
            int: total number of parameters
        """
        return sum(p.numel() for p in self.dropout_mask_dict.values())

__init__(net_distribution)

Initialize the pruner with all-ones dropout masks (nothing pruned yet).

Parameters:

Name Type Description Default
net_distribution dict[str, ParamDist]

posterior distribution of the net, which decides how probable a zero value is

required
Source code in src/methods/bayes/base/net_distribution.py
def __init__(self, net_distribution: BaseNetDistribution) -> None:
    """Initialize the pruner with all-ones dropout masks (nothing pruned yet).

    Args:
        net_distribution (BaseNetDistribution): posterior distribution of the
            net, which decides how probable a zero value is for each weight.
    """
    self.net_distribution = net_distribution
    # One binary mask per weight tensor; the shape is taken from a sample of
    # the corresponding parameter distribution.
    self.dropout_mask_dict: dict[str, nn.Parameter] = {}
    for name_dist, dist in self.net_distribution.weight_distribution.items():
        self.dropout_mask_dict[name_dist] = nn.Parameter(
            torch.ones_like(dist.sample())
        )

prune(threshold)

Prune all weights whose prune estimate (log_z_test) is lower than the threshold.

Source code in src/methods/bayes/base/net_distribution.py
def prune(self, threshold: float | dict[str, float]) -> None:
    """
    Prune every weight whose prune estimate (log_z_test) is below the threshold.

    Args:
        threshold: a single global threshold, or a mapping from weight name
            to a per-weight threshold.
    """
    for weight_name in self.net_distribution.weight_distribution:
        weight_threshold = threshold
        if isinstance(weight_threshold, dict):
            weight_threshold = weight_threshold[weight_name]
        self.prune_weight(weight_name, weight_threshold)

prune_stats()

Get number of pruned parameters.

Returns:

Name Type Description
int int

number of pruned parameters

Source code in src/methods/bayes/base/net_distribution.py
def prune_stats(self) -> int:
    """
    Get the number of pruned parameters (mask entries equal to 0).

    Returns:
        int: number of pruned parameters
    """
    prune_cnt = 0
    for dropout in self.dropout_mask_dict.values():
        # Masks hold 1.0 (kept) / 0.0 (pruned); count the zeros. Convert to a
        # plain Python int so the declared return type actually holds
        # (the previous code returned a 0-dim torch.Tensor).
        prune_cnt += int((1 - dropout).sum().item())
    return prune_cnt

prune_weight(weight_name, threshold)

Prune a weight if its prune estimate (log_z_test) is lower than the threshold.

Source code in src/methods/bayes/base/net_distribution.py
def prune_weight(self, weight_name: str, threshold: float) -> None:
    """
    Prune one weight tensor: entries whose prune estimate (log_z_test) is
    below the threshold are zeroed out in the base module.
    """
    self.set_weight_dropout_mask(weight_name, threshold)
    pt = get_attr(self.net_distribution.base_module, weight_name.split("."))
    pt = pt * self.dropout_mask_dict[weight_name]
    # Re-wrap as a Parameter so the base module keeps a trainable attribute.
    pt = nn.Parameter(pt)
    set_attr(self.net_distribution.base_module, weight_name.split("."), pt)

set_weight_dropout_mask(weight_name, threshold)

Set weight's dropout mask if its prune estimation (log_z_test) is lower than threshold.

Source code in src/methods/bayes/base/net_distribution.py
def set_weight_dropout_mask(self, weight_name: str, threshold: float) -> None:
    """
    Recompute the weight's dropout mask: 1.0 where log_z_test >= threshold
    (kept), 0.0 elsewhere (pruned).
    """
    dist = self.net_distribution.weight_distribution[weight_name]
    self.dropout_mask_dict[weight_name].data = 1.0 * (
        dist.log_z_test() >= threshold
    )

total_params()

Get total number of parameters.

Returns:

Name Type Description
int int

total number of parameters

Source code in src/methods/bayes/base/net_distribution.py
def total_params(self) -> int:
    """
    Get the total number of parameters covered by the dropout masks.

    Returns:
        int: total number of parameters
    """
    total = 0
    for mask in self.dropout_mask_dict.values():
        total += mask.numel()
    return total