Tensor and nn.Module Pruning #20402

Closed
mickypaganini opened this issue May 12, 2019 · 27 comments
Labels
feature A request for a proper, new feature. module: nn Related to torch.nn module: pruning triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@mickypaganini
Contributor

mickypaganini commented May 12, 2019

🚀 Tensor and nn.Module Pruning

Tensor method and/or nn util to sparsify tensors and/or model, according to various pruning techniques in the literature.

Motivation

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. By contrast, biological neural networks are known to use efficient sparse connectivity. It's important to identify the best techniques for compressing models by reducing the number of parameters in them, in order to reduce memory, battery, and hardware consumption without sacrificing accuracy, to deploy lightweight models on device, and to guarantee privacy through private on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets" [1]), as a destructive neural architecture search technique, and for other purposes.
Goal of this feature: harmonizing pruning practices by providing a standard interface in PyTorch.
Target audience: researchers, engineering and product teams.

Pitch

Minimalist API, with deeper flexibility for power-users.

At the tensor level, this could look as follows:

```python
t = torch.randn(3, 2, 4)
# e.g.: randomly mask 6 entries
pruned_t = t.prune(method='random', amount=6)
# e.g.: prune bottom 80% of entries by absolute value
pruned_t = t.prune(method='L1', amount=0.8)
# e.g.: prune 50% of channels along the last dimension by L1 norm
pruned_t = t.prune(method='L1Structured', amount=0.5)
# e.g.: prune 2 channels along the 0th dimension by L2 norm
pruned_t = t.prune(method='L2Structured', amount=2, axis=0)
# e.g.: prune 1 channel along the last dimension by L0 norm
pruned_t = t.prune(method='LnStructured', n=0, amount=1)
```

A not-in-place .prune method would return a torch.Tensor of the same type and size as the one it acts on.
In-place pruning would be supported via t.prune_(...).
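For comparison, the `torch.nn.utils.prune` module that was eventually merged (see the end of this thread) adds no `Tensor.prune` method, but its pruning-method classes can still be applied to a bare tensor outside any module. A minimal sketch, assuming a PyTorch version that ships that module; the class names are the merged ones rather than the `method='...'` strings proposed above, and `amount` follows the same convention (an int is a count, a float is a fraction):

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
t = torch.randn(3, 2, 4)  # 24 entries

# Randomly zero exactly 6 entries (an int amount is an absolute count).
pruned_random = prune.RandomUnstructured(amount=6).prune(t)

# Zero the 80% of entries with the smallest absolute value
# (a float amount is a fraction: round(0.8 * 24) = 19 entries).
pruned_l1 = prune.L1Unstructured(amount=0.8).prune(t)

# Zero 50% of the 4 slices along the last dimension (2 slices),
# dropping the slices with the smallest L1 norm.
pruned_l1_structured = prune.LnStructured(amount=0.5, n=1, dim=-1).prune(t)

# Zero 2 of the 3 slices along dimension 0, dropping the smallest L2 norms.
pruned_l2_structured = prune.LnStructured(amount=2, n=2, dim=0).prune(t)

# `t` itself is untouched; each call returns t * mask as a new tensor.
print(int((pruned_l1 == 0).sum()))  # 19
```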

At the model level, this will require a bit of thinking but should follow similar API patterns. This is important because not all pruning methods make sense on all parameters in a model (pruning conv kernels != pruning biases != pruning RNNs != pruning in the presence of batch norm, etc.).
First, we should have a sensible, well-documented default behavior for the average user's API, where a call to net.prune(method='L1', amount=0.8) defaults to pruning PyTorch "prepackaged" modules (such as linear, conv, and recurrent layers) in some sensible, expected way.
Most power users, though, would probably want to prune custom layers, or to prune different layer types, or layers at different depths, using different pruning methods or pruning-method parameters. This could be specified via a dictionary that maps parameter names (as found in net.state_dict().keys()) to a pruning method and its parameters:

```python
{
    'features.0.weight' : L1PruningMethod(amount=0.8),
    ...
}
```

Similar to the tensor operations, model-level pruning could return a copy of the model, or act on the model in place.
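The per-parameter dictionary can be approximated on top of the merged functional API without new machinery. A minimal sketch: `apply_plan` is a hypothetical helper, and the `plan` values here are a pruning function plus its keyword arguments rather than the `L1PruningMethod(...)` instances sketched above; copy-versus-in-place is left to the caller via `copy.deepcopy`:

```python
import copy

import torch.nn as nn
import torch.nn.utils.prune as prune


def apply_plan(model, plan):
    """Hypothetical helper: prune each named parameter of `model` in place, per `plan`."""
    modules = dict(model.named_modules())
    for full_name, (prune_fn, kwargs) in plan.items():
        # "0.weight" -> module "0", parameter "weight"
        module_name, _, param_name = full_name.rpartition(".")
        prune_fn(modules[module_name], param_name, **kwargs)
    return model


net = nn.Sequential(
    nn.Conv2d(1, 8, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 26 * 26, 10),
)

# Keys are names from net.state_dict().keys() before pruning.
plan = {
    "0.weight": (prune.l1_unstructured, {"amount": 0.8}),
    "3.weight": (prune.ln_structured, {"amount": 0.5, "n": 2, "dim": 0}),
}

# Copy semantics: prune a deep copy and leave `net` untouched.
pruned_net = apply_plan(copy.deepcopy(net), plan)

# In-place semantics would simply be: apply_plan(net, plan)
```

After pruning, `state_dict()` holds `0.weight_orig` and `0.weight_mask` in place of `0.weight`, so the plan's keys only match the names from before pruning.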

Pruning methods can be used during or after training; this implementation will be training-loop-agnostic: users will have to write their own training loop to decide when to prune and what to do with the pruned object (re-initialize and retrain, fine-tune, etc.).
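A minimal sketch of that training-loop-agnostic usage with the merged API, on toy random data: the user trains, prunes at a point of their choosing, and keeps training. The schedule (one pruning step after 20 updates) is arbitrary. The sketch relies on the merged implementation registering the original `Parameter` object as `weight_orig`, so an optimizer created before pruning keeps updating it:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Linear(10, 10)  # 100 weights
opt = torch.optim.SGD(model.parameters(), lr=0.05)  # plain SGD: no momentum, no weight decay
x, target = torch.randn(64, 10), torch.randn(64, 10)  # toy data


def train(steps):
    for _ in range(steps):
        opt.zero_grad()
        loss = F.mse_loss(model(x), target)
        loss.backward()
        opt.step()


train(20)  # dense phase

# Prune whenever the user decides to; here, 50 of the 100 weights by magnitude.
prune.l1_unstructured(model, "weight", amount=0.5)
orig_at_prune = model.weight_orig.detach().clone()

# Fine-tune: `weight` is recomputed as weight_orig * weight_mask before every
# forward, and the gradient reaching weight_orig is zero wherever the mask is zero.
train(20)

masked = model.weight_mask == 0
print(int(masked.sum()), int((model.weight == 0).sum()))
```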

Alternatives

Depending on where this will live within the codebase, t.prune(method=method, **kwargs) could also look like: torch.nn.utils.pruning_function(t, **kwargs) or torch.nn.utils.pruning_function(module=module, name=param_name, **kwargs). I personally would prefer the first option because the kwargs are parameters of the pruning method itself, while t is the tensor it acts on (some pruning methods will also have to take in some data X or (X, y) when they're applied), but the last option is more in line with how, say, weight_norm is implemented. Perhaps, for Module-level application, following the example of the weight_norm implementation using hooks will make this more PyTorch-y, but I don't know if we want to sacrifice the ability to act directly on a tensor that is not part of a Module. Would that go into torch.nn.functional? Open to suggestions here.
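To make the weight_norm-style option concrete, here is a minimal hand-rolled sketch of reparametrizing a parameter through a forward pre-hook. `mask_with_hook` is a hypothetical, deliberately simplified helper (no removal, no composition of masks) and is not the merged implementation; it also shows the limitation mentioned above, since a bare tensor has no module to carry the hook:

```python
import torch
import torch.nn as nn


def mask_with_hook(module, name, mask):
    """Hypothetical helper: make module.<name> equal <name>_orig * <name>_mask."""
    orig = getattr(module, name)
    del module._parameters[name]                     # `name` is no longer a Parameter...
    module.register_parameter(name + "_orig", orig)  # ...the trainable tensor lives here
    module.register_buffer(name + "_mask", mask)

    def recompute(mod, inputs):
        # Runs before every forward, so the masked tensor always reflects <name>_orig.
        setattr(mod, name, getattr(mod, name + "_orig") * getattr(mod, name + "_mask"))

    recompute(module, None)  # make module.<name> available immediately
    return module.register_forward_pre_hook(recompute)


torch.manual_seed(0)
layer = nn.Linear(4, 3)
mask = torch.ones_like(layer.weight)
mask[0] = 0.0  # drop every connection into output unit 0
handle = mask_with_hook(layer, "weight", mask)

out = layer(torch.randn(5, 4))
out.sum().backward()  # gradients flow to weight_orig, masked entries get zero
```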

cc @albanD @mruberry @jbschlosser

@jeffreyksmithjr jeffreyksmithjr added module: nn Related to torch.nn needs research We need to decide whether or not this merits inclusion, based on research world triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module feature A request for a proper, new feature. labels May 13, 2019
@ezyang
Contributor

ezyang commented May 15, 2019

cc @soumith. This sounds generally reasonable, but I am a little wary about signing up to figure out what "reasonable default pruning" behavior is for every module. Maybe this could live out of tree for the short term, and we can assess as people get more experience with what is 100%, obviously useful functionality, and what is a bit more opinion based and changes as research evolves?

@ezyang
Contributor

ezyang commented May 16, 2019

Some other comments that I had in private conversation with @mickypaganini:

  • There's probably some API bikeshedding to do on the exact signature of prune (and whether or not it lives as a method or function). All of these possibilities are easy to implement, we just have to decide what to do. One non-obvious sticking point is that writing binding code for operators which take "enums" (like the string L1/L1Structured) is a bit more involved, and so we should also consider just having separate functions/methods for each pruning mechanism (unless there are other design reasons why having it as an enum is a good idea; for example, there are other functions that need to use the same enum, or if users will want to parametrize over pruning strategy uniformly over many calls to the prune method)
  • If you do need to implement some actual CPU/CUDA kernels, it is probably easier to contribute them straight to PyTorch repository
  • For a specialized topic like this, it is very helpful to have someone who steps up to do continual maintenance and bugfixes for it
  • PyTorch prefers to put "obviously good ideas" in the library, and leave people the space to experiment

More motivation for merging directly to PyTorch, from Michela:

The only issue with [an external library for pruning] is that it'll become the (N+1)th solution for pruning models. I don't think anybody needs that. What we need is one centralized canonical way of doing it

Some other reference material:

@mickypaganini
Contributor Author

Thanks for summarizing.
Main roadblock now: deciding where tensor pruning will live.

  1. If we go the nn.utils route: will the signature of nn.utils.mypruningfunction be nn.utils.mypruningfunction(module, name, **kwargs) as it's done in the weight_norm example I linked? That specific .apply method uses hooks, so it assumes that the tensor I want to prune is inside some module which will undergo some forward pass. It won't support independent tensors being pruned outside of this world.
  2. To counter this, should we have a separate tensor method? One should be able to prune any given tensor even if they are not part of a module.

It feels weird to have the thing that controls pruning application in the forward and backward pass also handle the actual pruning logic implementation. Where should I have that logic live?

  • I'm happy to make each pruning technique a separate function/method and avoid string identifiers.
  • I'm also happy to maintain these functionalities long term.
  • We can table default module pruning for each specific nn.Module kind for a later discussion.

@michaelklachko

FYI, there's also temporal (lifetime) sparsity for activations:
https://arxiv.org/abs/1409.2752
https://arxiv.org/abs/1903.11257

@michaelklachko

Also, typically activation pruning is more aggressive during training (might even be disabled during test in some cases).
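As an illustration of the activation-side ideas above, here is a minimal sketch of lifetime-style sparsity as a standalone layer. During training, each hidden unit keeps only its `k` largest activations across the mini-batch; at evaluation time, the layer is a no-op. The class `LifetimeSparsity` and its `rate` argument are hypothetical. This is a simplified reading of the idea, not the exact method of either linked paper:

```python
import torch
import torch.nn as nn


class LifetimeSparsity(nn.Module):
    """Hypothetical layer: each unit keeps only its top `rate` fraction of batch activations."""

    def __init__(self, rate=0.05):
        super().__init__()
        self.rate = rate

    def forward(self, h):  # h: (batch, units)
        if not self.training:
            return h  # sparsity is applied during training only
        k = max(1, int(round(self.rate * h.shape[0])))
        # k-th largest activation of each unit across the batch, shape (1, units)
        threshold = h.topk(k, dim=0).values[-1:]
        return h * (h >= threshold).to(h.dtype)


torch.manual_seed(0)
layer = LifetimeSparsity(rate=0.05)
h = torch.randn(100, 8)

sparse_h = layer(h)  # training mode (the default): 5 survivors per unit
layer.eval()
dense_h = layer(h)   # eval mode: activations pass through unchanged
```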

@mickypaganini
Contributor Author

mickypaganini commented Jun 14, 2019

I started prototyping this idea using simple pruning methods: random and L1-based unstructured pruning, and random and Ln-based structured pruning. These are simple methods that are either random or only depend on the magnitude of the weights. You can check it out on my fork.
[Thanks @michaelklachko for the pointers to other pruning methods. We have a long list of candidate methods to implement in future iterations, including activation-based methods, weight evolution methods, etc.]

Pruning is currently located under torch/nn/utils, as suggested. The implementation follows the example of weight_norm.py, using forward pre-hooks to reparametrize a parameter in terms of the computed mask and the original parameter's value.

One would interact with pruning the same way as with weight norm, i.e., through functions like ln_structured_pruning(module, name, amount, n, axis), which applies pruning, and remove_pruning(module, name), which removes the hook and the reparametrization and permanently replaces getattr(module, name) with the pruned version of the parameter.

Pruning methods can be composed using a PruningContainer. This enables iterative pruning.

I'm still writing tests to ensure correctness.
Looking forward to your feedback!
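In the merged module, these functions ended up with the shorter names suggested later in this thread (`prune.l1_unstructured`, `prune.ln_structured`, `prune.remove`, and so on), and `axis` became `dim`. The sketch below shows the workflow described above with those names. Pruning the same parameter twice composes both methods in a `PruningContainer`, with the second `amount` taken from the entries still unpruned, and `prune.remove` then makes the result permanent. Reading `layer._forward_pre_hooks` is a private-attribute peek, used only to show the container:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # 100 weights

prune.l1_unstructured(layer, "weight", amount=0.2)  # prunes 20 -> 80 remain
prune.l1_unstructured(layer, "weight", amount=0.2)  # 20% of the 80 remaining = 16 more

hooks = list(layer._forward_pre_hooks.values())
print(type(hooks[0]).__name__)              # PruningContainer
print(int((layer.weight_mask == 0).sum()))  # 36

# Make it permanent: drop the hook, weight_orig and weight_mask,
# leaving a plain `weight` Parameter that contains the zeros.
prune.remove(layer, "weight")
```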

@michaelklachko

Note that pruning is similar to quantization, so perhaps the two can be implemented in a similar manner, or even combined. For example, binarizing ReLU activations and ternarizing weights can both be considered forms of pruning.
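To make the analogy concrete: ternarizing a weight tensor with a magnitude threshold zeroes exactly the entries that magnitude pruning with the same threshold would remove, and it quantizes the survivors to ±alpha. In this minimal sketch, `ternarize` is a hypothetical helper. The threshold rule (`delta = 0.7 * mean|w|`) and the choice of alpha as the mean magnitude of the surviving weights are illustrative assumptions:

```python
import torch


def ternarize(w, delta_factor=0.7):
    """Hypothetical helper: map w to {-alpha, 0, +alpha} with a magnitude threshold."""
    delta = delta_factor * w.abs().mean()
    keep = (w.abs() > delta).to(w.dtype)         # the implied pruning mask
    alpha = (w.abs() * keep).sum() / keep.sum()  # one scale shared by the survivors
    return alpha * torch.sign(w) * keep, keep


torch.manual_seed(0)
w = torch.randn(64, 32)
w_ternary, mask = ternarize(w)
w_pruned = w * mask  # plain magnitude pruning with the same threshold

print(f"fraction of weights zeroed: {1 - mask.mean().item():.2f}")
```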

@michaelklachko

michaelklachko commented Jun 14, 2019

Here's one example of how a quantization op can be integrated into a layer. In my opinion, it makes sense to implement weight pruning/quantization as a wrapper around a layer (e.g., like weight norm), whereas for activations it's probably better to have it as a separate layer type (e.g., like batch norm).

@soumith
Member

soumith commented Jun 20, 2019

@mickypaganini thanks a ton for the proposal, and your WIP branch.

I reviewed the overall (new) design using hooks, and focused more around nn.Module rather than tensors themselves. I think this design is less invasive and makes a lot more sense -- because pruning is primarily targeted for weights anyways. It looks great.

I have some bikeshedding points. I think you'd want to place the code in nn.utils.prune and then drop "pruning" from the names of all the functions you implemented.
For example:

```python
# Current
import torch.nn.utils as utils

utils.ln_structured_pruning(...)
utils.remove_pruning(...)

# New
import torch.nn.utils.prune as prune

prune.ln_structured(...)
prune.remove(...)
```

For each docstring, you probably also want to add an Example: section in addition to Args: and Output:.
I'm excited to see this unlock a bunch of pruning research.

@michaelklachko I disagree that quantization should be combined with pruning because:

  • Pruning is experimental in its research life cycle; not many people think about optimized code for a pruned model for all of the pruning methods that folks are currently researching.
  • Quantization on the other hand is extremely mature from the scientific computing perspective and in the neural network land -- there are optimized code and algorithms implemented for scale and shift quantization.
  • So, treating quantization as a subset of pruning in the interface and code paths will either limit how far we can extend the pruning interface (because we want it to also run fast) OR limit quantization to poor performance.
    So, I think I prefer quantization to be a separate, fully-fleshed out interface, as is being implemented in Model Quantization for PyTorch (Proposal) #18318 instead of being conflated with more exotic pruning techniques.

@soumith
Member

soumith commented Jun 20, 2019

@mickypaganini one thing to keep in mind, though, is that using forward hooks has its limitations.
For example, if you want to prune based on post-activation information, it's not very straight-forward to do so.

Example:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = P.random_structured(nn.Conv2d(1, 20, 5, 1), 'weight', 0.5, 1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        z = F.max_pool2d(y, 2, 2)
        return F.log_softmax(z, dim=1)
```

Here, you cannot prune based on the values of y using the current interface; you can prune only based on the values of x, conv1.weight, and conv1.bias, and at best conv1's output. y is the output of conv1 after a ReLU, which isn't materialized until all of self.conv1's pre- and post-hooks have already executed.
Is this okay, or do you see any pruning methods that cannot be done because of this limitation in the interface?

@michaelklachko

michaelklachko commented Jun 20, 2019

@soumith

* Quantization on the other hand is extremely mature

Oh I wish this were true! :) #18318 refers to only post-training quantization, and only deals with 8 bit precision. That's a good start, but it's trivial compared to all the active research on binary and ternary networks. As you go below 4 bits you start seeing accuracy degradation, inversely proportional to the model size. New papers get posted every week on ways to reduce that degradation. Just in the last six months I've seen ~10 papers claiming new state of the art :) Note that latest Nvidia cards offer native support for 1 bit ops. There are far more methods to perform quantization than to perform pruning :)

As I'm typing this, I'm training a convnet for a mixed signal chip we built, where the weights are analog, but activations have to be stored in digital memory, and therefore must be quantized. Note that in modern convnets activations consume a lot more memory than weights, so compressing them is a lot more effective from the point of view of memory constrained devices. So I'm currently exploring ways to binarize activations. How is this relevant to this ticket? As a separate effort, I've also experimented with several activation pruning methods to reduce power consumption, and I noticed that if I apply a temporal sparsity constraint on activations, it becomes easier to find the optimal clipping threshold for binarization. Basically, it's easier to binarize sparse activations, so the pruning can be considered as a first stage in the quantization process.

Having said that, I actually agree that pruning should be kept separate from quantization, not because one is more mature than the other (both are still very experimental), but because they are sufficiently distinct concepts, and it makes sense to apply them as separate steps (which especially helps during debugging). My point was more about treating them as similar types of operations. In fact, anyone interested in quantization should also be looking at pruning methods, and vice versa: how you do one affects the effectiveness of the other. The ultimate goal is to compress the model.

TLDR:

  1. Please don't ignore activations by focusing only on weights, because activation size matters far more for memory-constrained devices.
  2. Quantization and Pruning don't have to be combined, but should be treated as similar operations, with similar interfaces (when applied to weights/activations, post/during training). They both are ways to perform model compression.

@mickypaganini
Contributor Author

mickypaganini commented Jun 20, 2019

@soumith thanks for the positive feedback!

OK on prune.ln_structured instead of utils.ln_structured_pruning. I agree.

On your second point, perhaps a naive solution:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)

    def forward(self, x, return_hidden=False):
        y = F.relu(self.conv1(x))
        out = F.log_softmax(F.max_pool2d(y, 2, 2), dim=1)
        if return_hidden:
            return y, out
        else:
            return out
```

Then, similar to other pruning methods provided so far, we'd have to implement an activation-based pruning method like activation_based_pruning(module, name, X, amount) with module=net.conv1, name='weight', X=<some input data>, that would run the forward function with return_hidden=True to get the activations y, and compute a structured mask on the layer. Unlike the methods currently implemented in my work-in-progress, this family of pruning functions would require X as an argument.

An alternative would be to have activation_based_pruning(module, name, activations, amount), where activations is the activated output from that hidden layer, presumably computed in some previous forward pass.

Or we could even join the two and compute the activations from data X unless precomputed activations are passed in. If both are passed, maybe warn but give precedence to the precomputed activations? Not sure about this one...

Option 1 might be the least error-prone, but also the least flexible.
In any case, all of these let you prune the layer based on its activations on some representative data sample. Once one of these pruning functions is called, the mask is computed and stored in the forward pre-hook, so it's applied every time forward is called.

If instead you were referring to some sort of real-time pruning with scope limited to each individual forward call, then that would require rewriting the forward and computing the mask in some functional way, instead of using forward hooks. That only makes sense if one wants to temporarily prune the activations before passing them on to the next layer, instead of pruning units in the layer itself.
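The precomputed-activations option can be sketched with the merged API's `prune.custom_from_mask`, which accepts any user-computed mask. In this minimal sketch, the activations come from a separate, un-pruned forward pass over a stand-in batch. Ranking output channels by mean post-ReLU activation and also pruning the matching bias entries are assumptions made for illustration, and `activation_channel_mask` is a hypothetical helper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune


def activation_channel_mask(activations, n_prune):
    """Hypothetical helper: 1 for channels to keep, 0 for the n_prune least active ones."""
    score = activations.mean(dim=(0, 2, 3))  # mean activation per output channel
    mask = torch.ones_like(score)
    mask[score.topk(n_prune, largest=False).indices] = 0
    return mask


torch.manual_seed(0)
conv = nn.Conv2d(1, 20, 5, 1)
X = torch.randn(32, 1, 28, 28)  # stand-in for representative input data

with torch.no_grad():
    y = F.relu(conv(X))  # post-activation values, shape (32, 20, 24, 24)

channel_mask = activation_channel_mask(y, n_prune=10)
# Broadcast the per-channel mask to weight shape (20, 1, 5, 5); the bias mask is (20,).
prune.custom_from_mask(conv, "weight", channel_mask.view(-1, 1, 1, 1) * torch.ones_like(conv.weight))
prune.custom_from_mask(conv, "bias", channel_mask.clone())

out = conv(X)  # the pruned channels now output exactly zero
```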

@mickypaganini
Contributor Author

mickypaganini commented Jun 21, 2019

By the way, even the last case of ephemeral activation pruning can already be supported by the pruning module as is. It will require interacting with the <...>PruningMethod objects directly. Example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        # Prototype class; in the merged module this became P.LnStructured(amount, n, dim)
        self.pruning1 = P.LnStructuredPruningMethod(amount=0.5, n=2, axis=-1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        # A fresh mask is computed from this batch's activations; nothing is stored on the module
        mask = self.pruning1.compute_mask(y, default_mask=torch.ones_like(y))
        masked_y = y * mask.to(dtype=y.dtype)
        out = F.log_softmax(F.max_pool2d(masked_y, 2, 2), dim=1)
        return out
```

This uses no hooks.

@raghuramank100
Contributor

I agree, the goal is to unify interfaces for quantization and pruning at the level of APIs. Great point about activations.

The key difference is that there is a lot more hardware support for 8 bit (and now lower bitwidth) quantization, allowing for faster kernel implementations that run on a wide range of devices. Sparsity support is still limited to hardware accelerators and there seems to be a big distinction between sparsifying weights only (can be packed, treated like a constant for inference) and sparsifying activations (dynamic, requires specialized hw).

@michaelklachko

@raghuramank100

there seems to be a big distinction between sparsifying weights only (can be packed, treated like a constant for inference) and sparsifying activations (dynamic, requires specialized hw).

Just to clarify, there are three potential efficiency benefits from pruning (or quantization): reducing model size on disk, reducing memory consumption (during training or inference), and saving computation. Of these, memory consumption is arguably the biggest concern, and memory consumption during training in particular matters most to researchers (the PyTorch core user base). That's why methods like gradient checkpointing are popular when working with large models. Note that during training there's no difference between packing/unpacking weights and packing/unpacking activations.

@mickypaganini
Sometimes we want to gradually increase the amount of pruning done during training (or perhaps change some other pruning parameters). For example, in
self.pruning1 = P.LnStructuredPruningMethod(amount=0.5, n=2, axis=-1)
we might want to change amount according to some schedule or even make it a trainable parameter.
How would you deal with such a scenario?

@mickypaganini
Contributor Author

@michaelklachko: a quick solution would be to move the instantiation from the __init__ to the forward, so that you can pass in new pruning parameters at each forward call.

@michaelklachko

What do you mean? I'd probably do it like this, but I'm not sure whether this is considered good practice:

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pruning1 = P.LnStructuredPruningMethod(n=2, axis=-1)
        # or, if trainable: self.amount = torch.nn.Parameter(...)

    def forward(self, x, amount=0.5):
        y = F.relu(self.conv1(x))
        mask = self.pruning1.compute_mask(y, default_mask=torch.ones_like(y), amount=amount)  # or amount=self.amount
        masked_y = y * mask.to(dtype=y.dtype)
        out = F.log_softmax(F.max_pool2d(masked_y, 2, 2), dim=1)
        return out
```
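With the merged class names (`prune.LnStructured(amount, n, dim)` rather than the prototype's `LnStructuredPruningMethod(..., axis=...)`), the quick solution of building the method inside `forward` could look like the minimal sketch below. The `amount` argument on `forward` and the linear schedule `amount_at` are hypothetical. As far as I can tell, the merged methods turn `amount` into an integer count by rounding and then select entries with a top-k, so a trainable `amount` would not receive gradients through the mask; a schedule is the simpler route:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)

    def forward(self, x, amount=0.5):
        y = F.relu(self.conv1(x))
        # Build the method per call so `amount` can change between calls.
        method = prune.LnStructured(amount=amount, n=2, dim=-1)
        masked_y = method.prune(y)  # ephemeral: no hooks, nothing stored on the module
        return F.log_softmax(F.max_pool2d(masked_y, 2, 2), dim=1)


def amount_at(step, total_steps, final_amount=0.5):
    """Hypothetical linear schedule from 0 to final_amount."""
    return final_amount * min(step, total_steps) / total_steps


torch.manual_seed(0)
net = Net()
x = torch.randn(4, 1, 28, 28)
for step in range(0, 11, 5):
    out = net(x, amount=amount_at(step, total_steps=10))
```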

@YuJiang01

@mickypaganini, this is a very useful feature. As discussed, there are some issues we may need to consider:

  • Local pruning is handled very nicely in this design, but what about global pruning, which is also very common?
  • For pruning existing models, e.g., VGG and ResNet, which are already implemented in torchvision: under this framework, would all of those models have to be re-implemented?

@soumith
Member

soumith commented Jun 24, 2019

For pruning existing models, if it is via local pruning, then I don't think you need to reimplement them or modify their code, because you can simply loop over the modules and set the pruning hooks on them.

If it's activation pruning or global pruning, yes, you have to change the code of those models, and we shouldn't strive to find a solution where you don't touch the source code of those models.
Let's operate under the working assumption that copy-and-modify is perfectly acceptable.
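In the module that was eventually merged, both local and global magnitude pruning of weights can be applied from outside a model's source. Local pruning loops over submodules; global pruning uses `prune.global_unstructured`, which ranks several parameters together against a single threshold. A minimal sketch on small stand-in models. The same loops should work on torchvision models such as VGG or ResNet, since they only look for `Conv2d` and `Linear` submodules; `make_model` and `zeros_per_layer` are hypothetical helpers:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def make_model():
    return nn.Sequential(nn.Linear(20, 30), nn.ReLU(), nn.Linear(30, 10))  # 600 + 300 weights


def zeros_per_layer(model):
    return [int((m.weight == 0).sum()) for m in model.modules() if isinstance(m, nn.Linear)]


torch.manual_seed(0)

# Local: every Conv2d/Linear weight loses 30% of its own entries.
local_model = make_model()
for module in local_model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.l1_unstructured(module, "weight", amount=0.3)

# Global: 30% of all listed weights, ranked together, so individual layers
# can end up pruned by different amounts.
global_model = make_model()
to_prune = [(m, "weight") for m in global_model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.3)

print("local:", zeros_per_layer(local_model))    # [180, 90]
print("global:", zeros_per_layer(global_model))  # sums to 270; the split depends on the weights
```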

@karanchahal

I have a basic question about pruning. After applying a fine-grained pruning method that masks out weights randomly, how will speedups be achieved for these sparse weights in PyTorch? Are there plans to release a SparseConv2d, SparseLinear, etc.?

@soumith
Member

soumith commented Jul 11, 2019

pruning is at a stage where we aren't looking at speedups and performance. it's still researchy.

@karanchahal

Alright, thank you. Then what is the use of sparse tensor support? How can one leverage it for pruning? I ask because sparse tensors seem like a natural fit for representing pruned tensors.

@michaelklachko

PyTorch library for sparse NN training: https://github.com/TimDettmers/sparse_learning

@karanchahal

Thank you for this. I looked at the code, and the author masks out the weights with a boolean mask matrix. That is fine; however, I'm wondering whether PyTorch will support pruning using sparse tensors, that is, without computing the unnecessary multiplications by zero.
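On the sparse-tensor question: after `prune.remove`, a pruned weight can be converted to PyTorch's sparse COO format, and `torch.sparse.mm` then gives the same result as the dense layer. A minimal sketch. In line with the earlier comment that pruning is still researchy, nothing here is claimed to be faster than the dense path; whether it is depends on the sparsity level, the shapes, and the hardware:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(64, 32)  # 2048 weights
prune.l1_unstructured(layer, "weight", amount=0.9)
prune.remove(layer, "weight")  # bake the zeros into layer.weight

w_sparse = layer.weight.detach().to_sparse()  # COO format: stores only the nonzeros
x = torch.randn(8, 64)

# y = x W^T + b, computed as (W x^T)^T + b with a sparse W
out_sparse = torch.sparse.mm(w_sparse, x.t()).t() + layer.bias.detach()
out_dense = layer(x).detach()
print(w_sparse._nnz())  # 205 of 2048 entries kept
```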

@ThomAub

ThomAub commented Nov 7, 2019

Hello,
This is a great feature, and I think it's great to have a normalised API for pruning, quantisation, etc.

If it's activation pruning or global pruning, yes you have to change the code of those models, and we shouldn't strive to find a solution where you don't touch the source code of those models.

To handle global structured pruning, it should be possible to change the dimensions of some tensors and to propagate the change to the rest of the graph. Do you think this should be a feature of this work, or is it out of scope?
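Propagating structured pruning into smaller dense layers can be done by hand in simple cases. Below is a minimal sketch for two consecutive `Linear` layers with a ReLU between them; `shrink_linear_pair` is a hypothetical helper. It assumes that each pruned output unit of the first layer has zero weights and zero bias. Otherwise, a pruned unit would still emit `relu(bias)`, and removing it would change the output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune


def shrink_linear_pair(fc1, fc2, keep):
    """Hypothetical helper: drop fc1 output units (and matching fc2 inputs) where keep is False."""
    n = int(keep.sum())
    new_fc1 = nn.Linear(fc1.in_features, n)
    new_fc2 = nn.Linear(n, fc2.out_features)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep])
        new_fc1.bias.copy_(fc1.bias[keep])
        new_fc2.weight.copy_(fc2.weight[:, keep])
        new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2


torch.manual_seed(0)
fc1, fc2 = nn.Linear(10, 8), nn.Linear(8, 4)

# Structured pruning of half of fc1's output units (rows), by L2 norm,
# plus the matching bias entries so that pruned units output exactly relu(0) = 0.
prune.ln_structured(fc1, "weight", amount=0.5, n=2, dim=0)
keep = fc1.weight_mask.sum(dim=1) > 0
prune.custom_from_mask(fc1, "bias", keep.to(fc1.bias.dtype))

small_fc1, small_fc2 = shrink_linear_pair(fc1, fc2, keep)

x = torch.randn(5, 10)
with torch.no_grad():
    y_masked = fc2(F.relu(fc1(x)))              # original shapes, pruned units zeroed
    y_small = small_fc2(F.relu(small_fc1(x)))   # physically smaller layers
```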

facebook-github-bot pushed a commit that referenced this issue Nov 9, 2019
Summary:
Provides implementation for feature request issue #20402.

Adds pruning functionalities (structured and unstructured, local and global, as well as pruning from user-provided mask).

Associated tutorial here: pytorch/tutorials#605

cc: soumith
Pull Request resolved: #24076

Differential Revision: D18400431

Pulled By: mickypaganini

fbshipit-source-id: a97bd6ca61f8600ae411da9ff6533c232aae1a51
@gchanan gchanan removed the needs research We need to decide whether or not this merits inclusion, based on research world label Feb 25, 2020
@gchanan
Contributor

gchanan commented Feb 25, 2020

@mickypaganini should this be closed as already implemented? Or is there more to do?

@mickypaganini
Contributor Author

Yes, this can be closed now. This was implemented and merged. Upcoming changes and improvements to the pruning module are being tracked in other issues.


10 participants