Base Pruner
BaseNetDistributionPruner
Base pruner for NetDistribution. It prunes all weights that have a high probability of being 0.
Source code in src/methods/bayes/base/net_distribution.py
__init__(net_distribution)
Initialize the pruner with a posterior distribution over the network's weights.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| net_distribution | dict[str, ParamDist] | posterior distribution for the net, which determines how probable a zero value is | required |
prune(threshold)
Prune all weights whose prune estimate (`log_z_test`) is lower than `threshold`.
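A minimal sketch of how `prune(threshold)` could delegate to a per-weight prune for every named weight. This is not the library's code: the `ToyPruner` class is hypothetical, and so is the assumption that each weight is a flat list of floats stored next to a same-length list of per-element `log_z_test` scores.

```python
# Hypothetical sketch, not the library's implementation.
# Assumption: every weight is a flat list of floats, and a same-length
# list of per-element prune estimates (log_z_test) is stored next to it.


class ToyPruner:
    def __init__(self, weights, log_z_test):
        self.weights = weights          # dict[str, list[float]]
        self.log_z_test = log_z_test    # dict[str, list[float]]

    def prune_weight(self, weight_name, threshold):
        # Zero every element whose estimate falls below the threshold.
        scores = self.log_z_test[weight_name]
        self.weights[weight_name] = [
            0.0 if score < threshold else value
            for value, score in zip(self.weights[weight_name], scores)
        ]

    def prune(self, threshold):
        # Apply the same per-weight rule to every named weight.
        for name in self.weights:
            self.prune_weight(name, threshold)


pruner = ToyPruner(
    weights={"fc.weight": [0.5, -1.2, 0.03], "fc.bias": [0.1, 0.2]},
    log_z_test={"fc.weight": [2.0, 3.1, -4.0], "fc.bias": [-1.0, 0.5]},
)
pruner.prune(threshold=0.0)
print(pruner.weights)
# {'fc.weight': [0.5, -1.2, 0.0], 'fc.bias': [0.0, 0.2]}
```

Keeping `prune` as a thin loop means the pruning rule lives in one place, the per-weight method.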
prune_stats()
Get number of pruned parameters.
Returns:
| Type | Description |
|---|---|
| int | number of pruned parameters |
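One plausible way to count pruned parameters. It assumes, hypothetically, that a pruned parameter is one whose value has been set exactly to zero and that weights are nested lists of floats. The real method may count mask entries instead.

```python
# Hypothetical sketch: counts parameters equal to zero, which is only an
# assumption about what "pruned" means in the library.


def count_zeros(values):
    """Count exact zeros in an arbitrarily nested list of floats."""
    if isinstance(values, list):
        return sum(count_zeros(v) for v in values)
    return int(values == 0.0)


def prune_stats(weights):
    """Total number of pruned (zero) parameters over all named weights."""
    return sum(count_zeros(w) for w in weights.values())


weights = {
    "fc1.weight": [[0.0, 0.7], [0.0, -0.3]],
    "fc1.bias": [0.0, 0.1],
}
print(prune_stats(weights))  # 3
```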
prune_weight(weight_name, threshold)
Prune the weight `weight_name` if its prune estimate (`log_z_test`) is lower than `threshold`.
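Assuming the rule is applied element-wise, pruning a single weight reduces to one comparison per entry. In this sketch, `prune_matrix` is a hypothetical helper and plain nested lists stand in for tensors. An element is zeroed when its `log_z_test` score is lower than `threshold` and kept otherwise.

```python
# Hypothetical sketch of an element-wise rule behind prune_weight.
# Plain nested lists stand in for the library's tensors.


def prune_matrix(weight, log_z_test, threshold):
    """Return a copy of `weight` with low-scoring elements set to zero."""
    return [
        [0.0 if s < threshold else w for w, s in zip(w_row, s_row)]
        for w_row, s_row in zip(weight, log_z_test)
    ]


weight = [[0.8, -0.05], [0.02, 1.1]]
log_z_test = [[1.5, -2.0], [-3.0, 0.9]]
print(prune_matrix(weight, log_z_test, threshold=-1.0))
# [[0.8, 0.0], [0.0, 1.1]]
```

Raising `threshold` prunes more elements, since more scores fall below it.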
set_weight_dropout_mask(weight_name, threshold)
Set the dropout mask for the weight `weight_name` if its prune estimate (`log_z_test`) is lower than `threshold`.
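Unlike pruning, a dropout mask can leave the stored values untouched and hide them only at use time. The sketch below assumes, hypothetically, a mask of 0/1 entries (1 = keep) that is multiplied element-wise into the weight during the forward pass. The helpers `make_dropout_mask` and `masked` are illustrative names, not library functions.

```python
# Hypothetical sketch: build a 0/1 keep-mask from log_z_test scores and
# apply it without overwriting the underlying weight values.


def make_dropout_mask(log_z_test, threshold):
    """1.0 keeps an element, 0.0 drops it (score below threshold)."""
    return [0.0 if score < threshold else 1.0 for score in log_z_test]


def masked(weight, mask):
    """Effective weight seen by the forward pass."""
    return [w * m for w, m in zip(weight, mask)]


weight = [0.4, -0.9, 0.01, 0.3]
mask = make_dropout_mask([0.7, 2.2, -5.0, -0.1], threshold=0.0)
print(mask)                  # [1.0, 1.0, 0.0, 0.0]
print(masked(weight, mask))  # [0.4, -0.9, 0.0, 0.0]
print(weight)                # original values are unchanged
```

Keeping the mask separate makes pruning reversible: changing `threshold` just rebuilds the mask.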
total_params()
Get total number of parameters.
Returns:
| Type | Description |
|---|---|
| int | total number of parameters |
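Combining `total_params` with a pruned-parameter count gives the fraction of the network that has been pruned. The sketch below works over nested lists. `count_elements`, `count_zeros` and `pruned_fraction` are hypothetical helpers, and zero-valued entries are assumed to be the pruned ones.

```python
# Hypothetical sketch: weights as nested lists; "pruned" assumed to mean == 0.0.


def count_elements(values):
    """Number of scalar entries in an arbitrarily nested list."""
    if isinstance(values, list):
        return sum(count_elements(v) for v in values)
    return 1


def total_params(weights):
    """Total number of parameters over all named weights."""
    return sum(count_elements(w) for w in weights.values())


def count_zeros(values):
    """Number of entries equal to zero in a nested list."""
    if isinstance(values, list):
        return sum(count_zeros(v) for v in values)
    return int(values == 0.0)


def pruned_fraction(weights):
    """Share of parameters that are zero, in [0, 1]."""
    return sum(count_zeros(w) for w in weights.values()) / total_params(weights)


weights = {
    "fc1.weight": [[0.0, 0.7], [0.0, -0.3]],
    "fc1.bias": [0.0, 0.1],
}
print(total_params(weights))     # 6
print(pruned_fraction(weights))  # 0.5
```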