diff --git a/test/prototype/test_wanda_pp.py b/test/prototype/test_wanda_pp.py
new file mode 100644
index 0000000000..2fad3aaa7c
--- /dev/null
+++ b/test/prototype/test_wanda_pp.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch import nn
+from torch.testing._internal.common_pruning import SimpleLinear
+from torch.testing._internal.common_utils import TestCase
+
+from torchao.prototype.sparsity.pruner.wanda_pp import WandaPlusPlusSparsifier
+
+
+class TestWandaPlusPlusSparsifier(TestCase):
+    """Test Wanda++ Sparsifier"""
+
+    def _setup_model_and_sparsifier(self, model, sparsifier, block_configs):
+        """Helper to prepare the model and store calibration inputs"""
+        sparsifier.prepare(model, config=None)
+
+        # Store calibration inputs for each block
+        for block_name, input_shape in block_configs.items():
+            for _ in range(5):
+                calibration_input = torch.randn(1, *input_shape)
+                sparsifier.store_calibration_input(block_name, calibration_input)
+
+    def _verify_sparsity(self, layer, expected, tolerance=0.02):
+        """Helper to verify a layer's sparsity level"""
+        actual = (layer.weight == 0).float().mean()
+        assert abs(actual - expected) < tolerance, (
+            f"Expected ~{expected} sparsity, got {actual}"
+        )
+
+    def test_prepare_and_squash(self):
+        """Test that preparation and cleanup are inherited from Wanda"""
+        model = SimpleLinear()
+        sparsifier = WandaPlusPlusSparsifier()
+        sparsifier.prepare(model, config=None)
+
+        # Should inherit Wanda's preparation
+        assert hasattr(sparsifier.groups[0]["module"], "activation_post_process")
+
+        sparsifier.squash_mask()
+        assert not hasattr(sparsifier.groups[0]["module"], "activation_post_process")
+
+    def test_one_layer_sparsity(self):
+        """Test single-layer sparsification"""
+        model = nn.Sequential(nn.Linear(4, 1))
+        model[0].weight.data = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
+
+        sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
+        self._setup_model_and_sparsifier(model, sparsifier, {"layer_0": (4,)})
+
+        sparsifier.set_context(model[0], "layer_0")
+        model(torch.tensor([[100, 10, 1, 0.1]], dtype=torch.float32))
+        sparsifier.step()
+        sparsifier.squash_mask()
+
+        self._verify_sparsity(model[0], 0.5)
+
+    def test_multi_layer_sparsification(self):
+        """Test multi-layer sparsification"""
+        model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
+        sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
+
+        block_configs = {"layer_0": (128,), "layer_2": (200,)}
+        self._setup_model_and_sparsifier(model, sparsifier, block_configs)
+
+        model(torch.randn(100, 128))
+
+        # Sparsify each linear layer
+        for layer, block_name in [(model[0], "layer_0"), (model[2], "layer_2")]:
+            sparsifier.set_context(layer, block_name)
+            sparsifier.step()
+            self._verify_sparsity(layer, 0.5)
+
+        sparsifier.squash_mask()
+
+    def test_two_layer_mlp_unstructured_custom_config(self):
+        """Test a custom config for selective sparsification"""
+        model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
+        config = [{"tensor_fqn": "0.weight"}]
+
+        sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
+        sparsifier.prepare(model, config=config)
+
+        # Only store calibration inputs for the first layer
+        for _ in range(5):
+            sparsifier.store_calibration_input("layer_0", torch.randn(1, 128))
+
+        sparsifier.set_context(model[0], "layer_0")
+        model(torch.randn(100, 128))
+        sparsifier.step()
+
+        self._verify_sparsity(model[0], 0.5)
+        self._verify_sparsity(model[2], 0.0)
+        sparsifier.squash_mask()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/prototype/sparsity/pruner/__init__.py b/torchao/prototype/sparsity/pruner/__init__.py
index 9d7f775389..b35b567432 100644
--- a/torchao/prototype/sparsity/pruner/__init__.py
+++ b/torchao/prototype/sparsity/pruner/__init__.py
@@ -14,4 +14,5 @@
     "BiasHook",
     "FakeStructuredSparsity",
     "SaliencyPruner",
+    "WandaPlusPlusSparsifier",
 ]
diff --git a/torchao/prototype/sparsity/pruner/wanda_pp.py b/torchao/prototype/sparsity/pruner/wanda_pp.py
new file mode 100644
index 0000000000..4039d1a220
--- /dev/null
+++ b/torchao/prototype/sparsity/pruner/wanda_pp.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from torchao.sparsity import WandaSparsifier
+
+__all__ = ["WandaPlusPlusSparsifier"]
+
+
+# TODO: Implement Regional Optimization (RO)
+# TODO: Add a `prepare` function that builds quantization configs, as WandaSparsifier does
+class WandaPlusPlusSparsifier(WandaSparsifier):
+    r"""Wanda++ sparsifier, extending Wanda with regional gradients.
+
+    Wanda++ (Pruning by Weights and activations with Regional Gradients), proposed in
+    https://arxiv.org/abs/2503.04992, extends the Wanda method by incorporating
+    regional gradients for a more accurate pruning criterion.
+
+    The sparsifier removes weights based on the Regional Gradient Score (RGS):
+        S_ij = (α * G_ij + ||X_j||_2) * |W_ij|
+
+    where:
+    - G_ij: regional gradient computed from L^l_RGS(X^l_n) = ||f^l(X^l_n)||_2
+    - f^l: the l-th decoder block function
+    - X^l_n: the n-th input sample to the l-th decoder block
+    - α: scaling factor for regional gradients (default: 100, from the paper)
+
+    Args:
+        alpha: Regional gradient scaling factor (default: 100, from the paper)
+        calibration_samples: Number of samples for gradient computation (default: 32, from the paper)
+        **kwargs: Arguments passed to WandaSparsifier
+    """
+
+    def __init__(self, alpha: float = 100.0, calibration_samples: int = 32, **kwargs):
+        super().__init__(**kwargs)
+        self.defaults.update(
+            {"alpha": alpha, "calibration_samples": calibration_samples}
+        )
+        self._calibration_inputs = {}
+        self._current_decoder_block = None
+        self._current_block_name = None
+
+    def store_calibration_input(
+        self, block_name: str, input_tensor: torch.Tensor
+    ) -> None:
+        """Store calibration inputs for regional gradient computation"""
+        if block_name not in self._calibration_inputs:
+            self._calibration_inputs[block_name] = []
+
+        if (
+            len(self._calibration_inputs[block_name])
+            < self.defaults["calibration_samples"]
+        ):
+            self._calibration_inputs[block_name].append(input_tensor.detach().clone())
+
+    def set_context(self, decoder_block: nn.Module, block_name: str) -> None:
+        """Set the decoder block and block name for regional gradient computation"""
+        self._current_decoder_block = decoder_block
+        self._current_block_name = block_name
+
+    def update_mask(
+        self, module: nn.Module, tensor_name: str, sparsity_level: float, **kwargs
+    ) -> None:
+        """Update the mask using the Regional Gradient Score (RGS)"""
+
+        # Step 1: Get the tensor and the mask from the parametrizations
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+        tensor = getattr(module.parametrizations, tensor_name).original
+
+        # Step 2: Compute the Wanda++ pruning metric (RGS)
+        pruning_metric = self._compute_wandapp_metric(module, tensor, tensor_name)
+
+        # Step 3: Apply the sparsity pattern via WandaSparsifier's helper
+        self._apply_sparsity_pattern(mask, pruning_metric, sparsity_level, kwargs)
+
+    def _compute_wandapp_metric(
+        self, module: nn.Module, tensor: torch.Tensor, tensor_name: str
+    ) -> torch.Tensor:
+        """Compute the RGS: (α * G_ij + ||X_j||_2) * |W_ij|"""
+        activation_norm_per_channel = module.activation_post_process.norm
+        regional_gradients = self._compute_regional_gradients(module, tensor_name)
+
+        return (
+            self.defaults["alpha"] * regional_gradients
+            + activation_norm_per_channel.unsqueeze(0)
+        ) * tensor.abs()
+
+    def _compute_regional_gradients(
+        self, module: nn.Module, tensor_name: str
+    ) -> torch.Tensor:
+        """Compute regional gradients from the stored calibration inputs"""
+
+        inputs = self._calibration_inputs.get(self._current_block_name)
+        target_param = getattr(module.parametrizations, tensor_name).original
+        accumulated_gradients = torch.zeros_like(target_param)
+
+        self._current_decoder_block.eval()
+
+        # Compute L2-norm regional gradients
+        for input_tensor in inputs:
+            self._current_decoder_block.zero_grad()
+            with torch.enable_grad():
+                output = self._current_decoder_block(input_tensor)
+                torch.norm(output, p=2).backward()
+            if target_param.grad is not None:
+                accumulated_gradients += target_param.grad.abs()
+
+        return accumulated_gradients / len(inputs)
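To make the scoring rule in the docstring above concrete, here is a small, self-contained sketch (not part of the patch) of the Regional Gradient Score S_ij = (α * G_ij + ||X_j||_2) * |W_ij| on toy tensors; the shapes, values, and the 50% threshold step are illustrative stand-ins for what WandaPlusPlusSparsifier computes internally.

import torch

# Toy stand-ins: a linear layer with 3 output and 4 input channels (values are made up).
alpha = 100.0
W = torch.tensor([[0.5, -1.0, 0.2, 0.0],
                  [1.5, 0.1, -0.3, 0.7],
                  [-0.2, 0.4, 0.9, -1.1]])      # weights; |W_ij| comes from here
G = torch.rand(3, 4) * 1e-3                     # stand-in for the regional gradients G_ij
x_norm = torch.tensor([2.0, 0.5, 1.0, 0.1])     # per-input-channel activation norms ||X_j||_2

# Regional Gradient Score: S_ij = (alpha * G_ij + ||X_j||_2) * |W_ij|
score = (alpha * G + x_norm.unsqueeze(0)) * W.abs()

# Unstructured 50% sparsity keeps the half of the weights with the largest scores.
k = score.numel() // 2
threshold = score.flatten().kthvalue(k).values
mask = score > threshold
print(mask.float().mean())  # roughly 0.5 of the entries survive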
diff --git a/torchao/sparsity/wanda.py b/torchao/sparsity/wanda.py
index 7ad12a2d55..afee3aedf9 100644
--- a/torchao/sparsity/wanda.py
+++ b/torchao/sparsity/wanda.py
@@ -96,6 +96,17 @@ def update_mask(  # type: ignore[override]
 
         # Step 2: Calculate Wx
         pruning_metric = torch.abs(tensor) * activation_norm_per_channel
+        # Step 3: Apply sparsity pattern
+        self._apply_sparsity_pattern(mask, pruning_metric, sparsity_level, kwargs)
+
+    def _apply_sparsity_pattern(
+        self,
+        mask: torch.Tensor,
+        pruning_metric: torch.Tensor,
+        sparsity_level: float,
+        kwargs: dict,
+    ) -> None:
+        """Apply sparsity pattern based on pruning metric"""
         # defaults for unstructured sparsity
         block_size = pruning_metric.numel()
         num_specified = int(block_size * sparsity_level)
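For reference, a minimal end-to-end usage sketch assembled from the calls exercised in test_wanda_pp.py above; the two-layer toy model and the "layer_0"/"layer_2" block names mirror the tests and are illustrative, not an API requirement.

import torch
from torch import nn

from torchao.prototype.sparsity.pruner.wanda_pp import WandaPlusPlusSparsifier

# Toy model standing in for decoder blocks; the nn.Linear layers are the pruning targets.
model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))

sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)  # attaches Wanda's activation observers

# Store a handful of calibration inputs per block for the regional-gradient pass.
for _ in range(5):
    sparsifier.store_calibration_input("layer_0", torch.randn(1, 128))
    sparsifier.store_calibration_input("layer_2", torch.randn(1, 200))

# One forward pass so the observers record the per-channel activation norms ||X_j||_2.
model(torch.randn(100, 128))

# Prune each linear layer under its block context, then fold the masks into the weights.
for layer, block_name in [(model[0], "layer_0"), (model[2], "layer_2")]:
    sparsifier.set_context(layer, block_name)
    sparsifier.step()
sparsifier.squash_mask()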