Move files to prototype/sparsity (#1145)
jainapurva authored Oct 23, 2024
1 parent d252612 commit 4ef024c
Showing 62 changed files with 3,557 additions and 3,209 deletions.
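In practice, the move flips the two package segments in every import path. A minimal before/after sketch, using one of the relocated names that appears in the diffs below:

    # before this commit
    from torchao.sparsity.prototype import WeightNormSparsifier

    # after this commit
    from torchao.prototype.sparsity import WeightNormSparsifier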
@@ -1,11 +1,12 @@
 import logging
-import torch
 import unittest
+
+import torch
 from torch import nn
 from torch.nn.utils import parametrize
 from torch.testing._internal.common_utils import TestCase
 
-from torchao.sparsity.prototype.sparsifier import utils
+from torchao.prototype.sparsity.sparsifier import utils
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -168,5 +169,6 @@ def test_jit_trace(self):
         y_hat = model_trace(x)
         self.assertEqual(y_hat, y)
 
+
 if __name__ == "__main__":
     unittest.main()
@@ -1,10 +1,16 @@
-import warnings
 import unittest
+import warnings
 
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
 
-from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
+from torchao.prototype.sparsity import (
+    BaseScheduler,
+    CubicSL,
+    LambdaSL,
+    WeightNormSparsifier,
+)
+
 
 class ImplementedScheduler(BaseScheduler):
     def get_sl(self):
@@ -190,5 +196,6 @@ def test_step(self):
             msg="Sparsity level is not reaching the target level after delta_t * n steps",
         )
 
+
 if __name__ == "__main__":
     unittest.main()
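Since this file tests the relocated scheduler classes, here is a rough sketch of how a scheduler pairs with a sparsifier after the move. The toy model and ramp schedule are invented, and the LambdaSL signature is assumed to mirror torch.optim's LambdaLR pattern:

    from torch import nn

    from torchao.prototype.sparsity import LambdaSL, WeightNormSparsifier

    model = nn.Sequential(nn.Linear(16, 16))
    sparsifier = WeightNormSparsifier(sparsity_level=0.8)
    sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])

    # assumed API: scale the base sparsity level up over the first 10 steps
    scheduler = LambdaSL(sparsifier, sl_lambda=lambda step: min(1.0, step / 10))
    for _ in range(20):
        sparsifier.step()  # recompute masks at the current sparsity level
        scheduler.step()   # advance the sparsity-level schedule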
@@ -13,7 +13,12 @@
 )
 
 from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
-from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_5,
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
 
 logging.basicConfig(
@@ -88,7 +93,7 @@ def test_quant_semi_sparse(self, compile):
def test_sparse_marlin(self, compile):
if not torch.backends.cusparselt.is_available():
self.skipTest("Need cuSPARSELt")

input = torch.rand((256, 256)).half().cuda()
model = (
nn.Sequential(
@@ -117,7 +122,10 @@ def test_sparse_marlin(self, compile):


 class TestBlockSparseWeight(common_utils.TestCase):
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4,
+        "pytorch 2.4+ feature due to need for custom op support",
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [True, False])
def test_sparse(self, compile):
@@ -140,7 +148,7 @@ def test_sparse(self, compile):
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
dense_result = model(input)

-        from torchao.sparsity.prototype.superblock.blocksparse import (
+        from torchao.prototype.sparsity.superblock.blocksparse import (
             block_sparse_weight,
         )

@@ -167,7 +175,7 @@ def test_sparse(self, compile):
.cuda()
.eval()
)
-        from torchao.sparsity.prototype.superblock.blocksparse import (
+        from torchao.prototype.sparsity.superblock.blocksparse import (
             blocksparse_int_addmm,
         )
from torchao.sparsity.utils import create_block_sparse_tensor
@@ -189,9 +197,7 @@ def test_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(
-                layout=BlockSparseLayout(blocksize=64)
-            ),
+            int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)),
         )
if compile:
model = torch.compile(model)
@@ -7,12 +7,6 @@
 
 import torch
 from torch import nn
-from torchao.sparsity.prototype import (
-    BaseSparsifier,
-    FakeSparsity,
-    NearlyDiagonalSparsifier,
-    WeightNormSparsifier,
-)
 from torch.nn.utils.parametrize import is_parametrized
 from torch.testing._internal.common_pruning import (
     ImplementedSparsifier,
@@ -21,6 +15,12 @@
 )
 
 from torch.testing._internal.common_utils import TestCase
+from torchao.prototype.sparsity import (
+    BaseSparsifier,
+    FakeSparsity,
+    NearlyDiagonalSparsifier,
+    WeightNormSparsifier,
+)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -486,5 +486,6 @@ def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
                 else:
                     assert mask[row, col] == 0
 
+
 if __name__ == "__main__":
     unittest.main()
@@ -2,11 +2,6 @@
 import unittest
 
 import torch
-from torchao.sparsity.prototype.sparsifier.utils import (
-    fqn_to_module,
-    get_arg_info_from_tensor_fqn,
-    module_to_fqn,
-)

from torch.testing._internal.common_quantization import (
ConvBnReLUModel,
@@ -18,6 +13,11 @@
     TwoLayerLinearModel,
 )
 from torch.testing._internal.common_utils import TestCase
+from torchao.prototype.sparsity.sparsifier.utils import (
+    fqn_to_module,
+    get_arg_info_from_tensor_fqn,
+    module_to_fqn,
+)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -6,13 +6,6 @@
 
 import torch
 from torch import nn
-from torchao.sparsity.prototype.pruner import (
-    BaseStructuredSparsifier,
-    FakeStructuredSparsity,
-    FPGMPruner,
-    LSTMSaliencyPruner,
-    SaliencyPruner,
-)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
Conv2dActivation,
@@ -32,6 +25,13 @@
 )
 
 from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
+from torchao.prototype.sparsity.pruner import (
+    BaseStructuredSparsifier,
+    FakeStructuredSparsity,
+    FPGMPruner,
+    LSTMSaliencyPruner,
+    SaliencyPruner,
+)


logging.basicConfig(
@@ -1093,5 +1093,6 @@ def test_update_mask(self):
             expected_conv1, expected_conv2, device
         )
 
+
 if __name__ == "__main__":
     unittest.main()
20 changes: 20 additions & 0 deletions torchao/prototype/sparsity/__init__.py
@@ -0,0 +1,20 @@
# Scheduler
from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler
from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL
from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL

# Sparsifier
from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier
from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import (
    NearlyDiagonalSparsifier,
)

# Parametrizations
from torchao.prototype.sparsity.sparsifier.utils import (
FakeSparsity,
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)
from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
WeightNormSparsifier,
)
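For context, a minimal sketch of driving one of these exported sparsifiers from its new home; the model and config are hypothetical, and prepare/step/squash_mask are the usual BaseSparsifier entry points:

    from torch import nn

    from torchao.prototype.sparsity import NearlyDiagonalSparsifier

    model = nn.Sequential(nn.Linear(8, 8))
    sparsifier = NearlyDiagonalSparsifier(nearliness=1)
    sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
    sparsifier.step()         # mask weights off the (near-)diagonal band
    sparsifier.squash_mask()  # fold the masks into the weights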
93 changes: 93 additions & 0 deletions torchao/prototype/sparsity/pruner/FPGM_pruner.py
@@ -0,0 +1,93 @@
from typing import Callable, Optional, Union

import torch

from .base_structured_sparsifier import BaseStructuredSparsifier

__all__ = ["FPGMPruner"]


class FPGMPruner(BaseStructuredSparsifier):
    r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner
    This sparsifier prunes filters (rows) in a tensor according to the distances
    among filters, following
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    This sparsifier is controlled by two variables:
    1. `sparsity_level` defines the proportion of filters (rows) that are zeroed out.
    2. `dist` defines the distance measurement type. Default: 2 (L2 distance).
       Available options are: [1, 2, (custom callable distance function)].
    Note::
        Inputs should be a 4D convolutional tensor of shape (N, C, H, W).
        - N: output channels size
        - C: input channels size
        - H: height of kernel
        - W: width of kernel
    """

def __init__(
self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None
):
defaults = {
"sparsity_level": sparsity_level,
}

if dist is None:
dist = 2

if callable(dist):
self.dist_fn = dist
elif dist == 1:
self.dist_fn = lambda x: torch.cdist(x, x, p=1)
elif dist == 2:
self.dist_fn = lambda x: torch.cdist(x, x, p=2)
else:
raise NotImplementedError("Distance function is not yet implemented.")
super().__init__(defaults=defaults)

    def _compute_distance(self, t):
        r"""Compute distance across all entries in tensor `t` along all dimensions
        except for the one identified by dim.
        Args:
            t (torch.Tensor): tensor representing the parameter to prune
        Returns:
            distance (torch.Tensor): distance computed across filters
        """
        dim = 0  # prune filter (row)

size = t.size(dim)
slc = [slice(None)] * t.dim()

# flatten the tensor along the dimension
t_flatten = [
t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1)
for i in range(size)
]
t_flatten = torch.stack(t_flatten)

# distance measurement
dist_matrix = self.dist_fn(t_flatten)

        # a filter that is more similar to the others has a smaller row sum;
        # such filters lie near the geometric median and are pruned first
        distance = torch.sum(torch.abs(dist_matrix), 1)

return distance

def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
tensor_weight = getattr(module, tensor_name)
mask = getattr(module.parametrizations, tensor_name)[0].mask

if sparsity_level <= 0:
mask.data = torch.ones_like(mask).bool()
elif sparsity_level >= 1.0:
mask.data = torch.zeros_like(mask).bool()
else:
distance = self._compute_distance(tensor_weight)

tensor_size = tensor_weight.shape[0] # prune filter (row)
nparams_toprune = round(sparsity_level * tensor_size)
nparams_toprune = min(
max(nparams_toprune, 0), tensor_size
) # clamp to [0, tensor_size]
topk = torch.topk(distance, k=nparams_toprune, largest=False)
mask[topk.indices] = False
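A hedged sketch of how this pruner is meant to be driven; the conv stack and config are invented for illustration, and prepare/step come from the BaseStructuredSparsifier base class imported above:

    from torch import nn

    from torchao.prototype.sparsity.pruner import FPGMPruner

    # FPGM scores each output filter of a conv weight (N, C, H, W) along N
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
    pruner = FPGMPruner(sparsity_level=0.5, dist=2)  # L2 distance between filters
    pruner.prepare(model, config=[{"tensor_fqn": "0.weight"}])
    pruner.step()  # zero the masks of the filters closest to the geometric median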
File renamed without changes.
8 changes: 8 additions & 0 deletions torchao/prototype/sparsity/pruner/__init__.py
@@ -0,0 +1,8 @@
from .base_structured_sparsifier import BaseStructuredSparsifier
from .parametrization import (
FakeStructuredSparsity,
BiasHook,
)
from .saliency_pruner import SaliencyPruner
from .lstm_saliency_pruner import LSTMSaliencyPruner
from .FPGM_pruner import FPGMPruner