Skip to content

Commit

Permalink
add helper function compute_sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Jun 9, 2024
1 parent 88efa51 commit 497e5fd
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions fl_sim/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"is_notebook",
"find_longest_common_substring",
"add_kwargs",
"compute_sparsity",
]


Expand Down Expand Up @@ -517,3 +518,27 @@ def execute_cmd(cmd: Union[str, List[str]], raise_error: bool = True) -> Tuple[i
exitcode = 0

return exitcode, output_msg


def compute_sparsity(model_or_tensor: Union[torch.nn.Module, torch.Tensor]) -> float:
"""
Compute the sparsity of a model or tensor.
Parameters
----------
model_or_tensor : torch.nn.Module or torch.Tensor
A model or tensor.
Returns
-------
float
The sparsity of the model or tensor.
"""
nonzeros, n_params = 0, 0
if isinstance(model_or_tensor, torch.Tensor):
return model_or_tensor.abs().sign().sum().item() / model_or_tensor.numel()
for param in model_or_tensor.parameters():
n_params += param.data.numel()
nonzeros += param.abs().sign().sum().int().item()
return nonzeros / n_params

0 comments on commit 497e5fd

Please sign in to comment.