Merge pull request pytorch#22 from pytorch-labs/jcaip/sparsity
[sparse] add sparsity, add wanda sparsifier to ao
jcaip committed Dec 11, 2023
2 parents a55e2d2 + 9c84256 commit fc6b3dc
Showing 4 changed files with 285 additions and 0 deletions.
114 changes: 114 additions & 0 deletions test/sparsity/test_wanda.py
@@ -0,0 +1,114 @@
import logging
import unittest

import torch
from torch import nn
from torchao.sparsity import WandaSparsifier
from torch.ao.pruning import FakeSparsity
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import SimpleLinear
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestWandaSparsifier(TestCase):
"""
Test Wanda Sparsifier
"""

def test_prepare(self):
model = SimpleLinear()
sparsifier = WandaSparsifier()
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
module = g["module"]
# Check mask exists
assert hasattr(module.parametrizations["weight"][0], "mask")
# Check parametrization exists and is correct
assert is_parametrized(module, "weight")
assert type(module.parametrizations.weight[0]) == FakeSparsity
# check activation observer is present
assert hasattr(module, "activation_post_process")

def test_squash_mask(self):
# check observers and parameterizations removed
model = SimpleLinear()
sparsifier = WandaSparsifier()
sparsifier.prepare(model, config=None)
sparsifier.squash_mask()
for g in sparsifier.groups:
module = g["module"]
assert not is_parametrized(module, "weight")
assert not hasattr(module, "mask")
assert not hasattr(module, "activation_post_process")

def test_one_layer_mlp_2x4(self):
model = nn.Sequential(nn.Linear(8, 1))
weights = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
model[0].weight.data.copy_(weights.data)
X = torch.ones(1, 8)

sparsifier = WandaSparsifier(semi_structured_block_size=4)
sparsifier.prepare(model, config=None)

model(X)

sparsifier.step()
sparsifier.squash_mask()

sparsity = (model[0].weight == 0).float().mean()
assert sparsity == 0.5

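        # With all-ones calibration input, every per-channel activation norm is identical,
        # so the Wanda metric reduces to |W|: within each block of 4, the two smallest
        # weights (1, 2 and 5, 6) are zeroed, giving exact 2:4 sparsity.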
expected_fc = torch.tensor([[0, 0, 3, 4, 0, 0, 7, 8]], dtype=torch.float32)
assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

def test_one_layer_mlp_unstructured(self):
model = nn.Sequential(nn.Linear(4, 1))
weights = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
model[0].weight.data.copy_(weights.data)
X = torch.tensor([[100, 10, 1, 0.1]], dtype=torch.float32)

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)

model(X)

sparsifier.step()
sparsifier.squash_mask()

sparsity = (model[0].weight == 0).float().mean()
assert sparsity == 0.5

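        # The observer records the squared L2 norm of each input channel, [100^2, 10^2, 1^2, 0.1^2];
        # multiplying elementwise by |W| = [1, 2, 3, 4] gives [1e4, 2e2, 3, 4e-2], so the two
        # lowest-scoring weights (columns 2 and 3) are the ones pruned.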
expected_fc = torch.tensor([[1, 2, 0, 0]], dtype=torch.float32)
assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

def test_two_layer_mlp_unstructured(self):
model = nn.Sequential(
nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
) # C_in by C_out
X1 = torch.randn(100, 128) # B1 by C_in
X2 = torch.randn(50, 128) # B2 by C_in

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)

model(X1)
model(X2)
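        # per-channel activation norms accumulate in the observers across both calibration batches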
sparsifier.step()

cnt = 0
for m in model.modules():
if isinstance(m, nn.Linear):
cnt += 1
sparsity_level = (m.weight == 0).float().mean()
assert (
sparsity_level == 0.5
), f"sparsity for linear layer {cnt} should be 0.5"

sparsifier.squash_mask()

if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions torchao/sparsity/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403

__all__ = [
"WandaSparsifier",
"PerChannelNormObserver"
]
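
With these re-exports in place, downstream code (the test above, for example) can pull the sparsifier straight from the package root; the observer is exposed the same way:

from torchao.sparsity import WandaSparsifier, PerChannelNormObserver
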
48 changes: 48 additions & 0 deletions torchao/sparsity/utils.py
@@ -0,0 +1,48 @@
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase

__all__ = ["PerChannelNormObserver"]

# Observers
class PerChannelNormObserver(UniformQuantizationObserverBase):
"""
A custom observer that computes the L2 norm of each channel and stores it in a buffer.
"""

def __init__(self, **kwargs) -> None:
# init with fixed qparams for quantization flow
super().__init__(
dtype=torch.quint8,
qscheme=torch.per_channel_affine,
reduce_range=False,
quant_min=None,
quant_max=None,
eps=torch.finfo(torch.float32).eps,
**kwargs
)
# set averaging constant so quantization flow knows observer is memoryless.
self.averaging_constant = 1.0
self.register_buffer("norm", torch.tensor([]))

def forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach() # avoid keeping autograd tape

        # the channel axis is assumed to be the last dimension; move it to the front
new_axis_list = [i for i in range(x.dim())] # noqa: C416
new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
y = x.permute(new_axis_list)
y = torch.flatten(y, start_dim=1)
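        # squared L2 norm of each input channel over this batch; a running total is kept
        # in the `norm` buffer below, so repeated forward calls accumulate statistics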
norm = torch.norm(y, dim=1) ** 2

if self.norm.numel() == 0:
self.norm.resize_(norm.shape)
self.norm.copy_(norm)
else:
self.norm += norm

return x_orig

def calculate_qparams(self):
raise NotImplementedError("PerChannelNormObserver is designed to store activations only. ")
110 changes: 110 additions & 0 deletions torchao/sparsity/wanda.py
@@ -0,0 +1,110 @@

import warnings

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier
from torch.ao.quantization import default_placeholder_observer, QConfig
from torch.ao.quantization.quantize import _remove_qconfig
from .utils import PerChannelNormObserver

__all__ = ["WandaSparsifier"]


class WandaSparsifier(BaseSparsifier):
r"""Wanda sparsifier
Wanda (Pruning by Weights and activations), proposed in https://arxiv.org/abs/2306.11695
is an activation aware pruning method. The sparsifier removes weights based on the product
of the input activation norm and the weight magnitude.
This sparsifier is controlled by three variables:
1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out;
Args:
sparsity_level: The target level of sparsity;
model: The model to be sparsified;
"""

def __init__(
self,
sparsity_level: float = 0.5,
semi_structured_block_size: Optional[int] = None,
):
defaults = {
"sparsity_level": sparsity_level,
"semi_structured_block_size": semi_structured_block_size,
}
if semi_structured_block_size is not None:
m = semi_structured_block_size
warnings.warn(
f"WandaSparsifier got semi_structured_bock_size={m}, sparsity_level fixed to 50% ({m // 2}:{m}) sparsity"
)
super().__init__(defaults=defaults)

def prepare(self, model: nn.Module, config: List[Dict]) -> None:
# activation: use PerChannelNormObserver
# use no-op placeholder weight observer
model.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
torch.ao.quantization.prepare(model, inplace=True)

# call superclass prepare
super().prepare(model, config)

def update_mask( # type: ignore[override]
self, module: nn.Module, tensor_name: str, sparsity_level: float, **kwargs
) -> None:
r"""Pruning function for WandaSparsifier
The activation statistics is retrieved first in the `act_per_input` variable.
Then the Wanda pruning metric is computed. The weight matrix is then pruned
by comparing this metric across the whole current layer.
"""

# Step 1: get the tensor and the mask from the parametrizations
mask = getattr(module.parametrizations, tensor_name)[0].mask
tensor = getattr(module.parametrizations, tensor_name).original
activation_norm_per_channel = module.activation_post_process.norm

# Step 2: Calculate Wx
pruning_metric = torch.abs(tensor) * activation_norm_per_channel
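        # e.g. (see test_one_layer_mlp_unstructured) weights [1, 2, 3, 4] with accumulated
        # squared channel norms [100^2, 10^2, 1^2, 0.1^2] give a metric of [1e4, 2e2, 3, 4e-2],
        # so the lowest-scoring entries are pruned first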

# defaults for unstructured sparsity
block_size = pruning_metric.numel()
num_specified = int(block_size * sparsity_level)
# if set to use semi-structured, ignore sparsity_level
if kwargs.get("semi_structured_block_size", None) is not None:
block_size = kwargs["semi_structured_block_size"]
num_specified = block_size // 2
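            # e.g. semi_structured_block_size=4 prunes 2 of every 4 weights (2:4 sparsity)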

        # get indices to prune
pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[
:, :num_specified
]
# update mask
mask.data.view(-1, block_size).scatter_(
1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
)

def squash_mask(
self,
params_to_keep: Optional[Tuple[str, ...]] = None,
params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
*args,
**kwargs,
):
# remove quantization config
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
_remove_qconfig(module)

# remove parameterizations
super().squash_mask(
params_to_keep=params_to_keep,
params_to_keep_per_layer=params_to_keep_per_layer,
)

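Putting the pieces together, the intended flow mirrors the tests above: prepare() attaches a PerChannelNormObserver to each target module, calibration forward passes populate the per-channel activation norms, step() applies the Wanda masks, and squash_mask() folds the zeros into the weights and removes the observers and parametrizations. A minimal usage sketch (the model and calibration data below are placeholders, not part of the commit):

import torch
from torch import nn
from torchao.sparsity import WandaSparsifier

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
calibration_batches = [torch.randn(32, 128) for _ in range(4)]

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)

# calibration: forward passes accumulate activation norms in the observers
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

sparsifier.step()         # compute the Wanda metric and update the masks
sparsifier.squash_mask()  # bake the masks into the weights, drop observers

for m in model.modules():
    if isinstance(m, nn.Linear):
        print((m.weight == 0).float().mean())  # ~0.5 for each linear layer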