Tensor and nn.Module Pruning #20402
cc @soumith. This sounds generally reasonable, but I am a little wary about signing up to figure out what "reasonable default pruning" behavior is for every module. Maybe this could live out of tree for the short term, and we can assess, as people get more experience, what is 100% obviously useful functionality and what is a bit more opinion-based and changes as research evolves?
Some other comments that I had in private conversation with @mickypaganini:
More motivation for merging directly into PyTorch, from Michela:
Some other reference material:
Thanks for summarizing.
It feels weird to have the thing that controls pruning application in the forward and backward pass also handle the actual pruning logic implementation. Where should I have that logic live?
FYI, there's also temporal (lifetime) sparsity for activations:
Also, activation pruning is typically more aggressive during training (it might even be disabled at test time in some cases).
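The two comments above can be combined in a small sketch. It is written in pure Python and is not a PyTorch API. It implements lifetime sparsity: for every unit (column), only the largest `keep_frac` fraction of that unit's activations across the mini-batch survives. The `training` flag turns the pruning off at eval time. The function name, the list-of-rows layout, and the `keep_frac` parameter are all hypothetical.

```python
import math


def lifetime_sparsify(acts, keep_frac, training=True):
    """Hypothetical helper: lifetime (temporal) sparsity over a mini-batch.

    acts: list of rows, one per example; each row lists unit activations.
    For every unit, keep the ceil(keep_frac * batch) largest activations
    across the batch and zero the rest. When training is False the
    activations are returned unchanged (pruning disabled at test time).
    """
    if not training:
        return [list(row) for row in acts]
    batch = len(acts)
    n_units = len(acts[0])
    k = max(1, math.ceil(keep_frac * batch))
    out = [[0.0] * n_units for _ in range(batch)]
    for j in range(n_units):
        column = [acts[i][j] for i in range(batch)]
        # Examples on which unit j fires most strongly ("winners").
        winners = sorted(range(batch), key=lambda i: column[i], reverse=True)[:k]
        for i in winners:
            out[i][j] = column[i]
    return out


acts = [[0.1, 2.0], [0.9, 0.0], [0.4, 1.5], [0.3, 0.2]]
sparse_train = lifetime_sparsify(acts, keep_frac=0.5)
dense_eval = lifetime_sparsify(acts, keep_frac=0.5, training=False)
```

Because the constraint is computed across the batch, it needs a batch to act on. This is one reason such activation pruning is easier to apply during training than at inference.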
I started prototyping this idea using simple pruning methods: random and L1-based unstructured pruning, and random and Ln-based structured pruning. These are simple methods that are either random or depend only on the magnitude of the weights. You can check it out on my fork. Pruning is currently located under … One would interact with pruning in a similar way as with weight norm, i.e. through functions like … Pruning methods can be composed using a … I'm still writing tests to ensure correctness.
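The following pure-Python sketch shows what L1-based unstructured pruning computes and how two pruning passes can be composed. It is not the code on the fork. When passes are composed, the second pass only considers weights that survived the first, so already-pruned entries stay at zero. The helper names are hypothetical. The merged `torch.nn.utils.prune` module exposes the tensor version of this as `l1_unstructured`.

```python
def l1_unstructured_mask(weights, amount, mask=None):
    """Hypothetical helper: 0/1 mask pruning the `amount` fraction of the
    still-unpruned weights that have the smallest absolute value.

    Passing an existing `mask` composes the new pass with the old one:
    entries that are already 0 stay pruned and are not counted again.
    """
    if mask is None:
        mask = [1] * len(weights)
    alive = [i for i, m in enumerate(mask) if m]
    k = round(amount * len(alive))
    # Indices of the k smallest-magnitude surviving weights.
    to_prune = sorted(alive, key=lambda i: abs(weights[i]))[:k]
    new_mask = list(mask)
    for i in to_prune:
        new_mask[i] = 0
    return new_mask


def apply_mask(weights, mask):
    """Element-wise product of weights and mask (the pruned tensor)."""
    return [w * m for w, m in zip(weights, mask)]


w = [0.5, -0.1, 0.05, -2.0, 0.3, 0.0, 1.2, -0.7]
m1 = l1_unstructured_mask(w, 0.25)     # prunes 2 of 8: [1, 1, 0, 1, 1, 0, 1, 1]
m2 = l1_unstructured_mask(w, 0.5, m1)  # prunes 3 of the remaining 6
pruned = apply_mask(w, m2)
```

After both passes, 5 of the 8 weights are zero. That matches 1 - (1 - 0.25) * (1 - 0.5) = 0.625, which is what "pruning a fraction of what remains" implies.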
Note that the pruning operation is similar to quantization, so perhaps the two can be implemented in a similar manner, or even combined. For example, binarizing ReLU activations and ternarizing weights can both be considered forms of pruning.
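To make the point that ternarization contains a pruning step concrete, here is a pure-Python sketch of threshold-based weight ternarization. Every weight whose magnitude falls below a threshold `delta` becomes 0, so that part is magnitude pruning. The survivors are mapped to +alpha or -alpha. Setting `delta` to 0.7 times the mean absolute weight, with `alpha` equal to the mean magnitude of the survivors, is a heuristic from the ternary-weight literature; treat it as an assumption here. The function names are hypothetical.

```python
def ternarize(weights, delta_factor=0.7):
    """Hypothetical helper: map weights to {-alpha, 0, +alpha}.

    delta_factor * mean(|w|) is the pruning threshold (an assumed
    heuristic); alpha is the mean magnitude of the weights that survive.
    Returns (codes in {-1, 0, 1}, alpha, dequantized weights).
    """
    mean_abs = sum(abs(w) for w in weights) / len(weights)
    delta = delta_factor * mean_abs
    kept = [abs(w) for w in weights if abs(w) > delta]
    alpha = sum(kept) / len(kept) if kept else 0.0
    codes = [0 if abs(w) <= delta else (1 if w > 0 else -1) for w in weights]
    return codes, alpha, [alpha * c for c in codes]


def pruning_mask(codes):
    """The zero codes are exactly a magnitude-pruning mask."""
    return [1 if c != 0 else 0 for c in codes]


codes, alpha, q = ternarize([0.9, -0.05, 0.1, -1.1, 0.02, 0.6])
```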
Here's one example of how a quantization op can be integrated into a layer. In my opinion, it makes sense to implement weight pruning/quantization as a wrapper for a layer (e.g. like weight norm); for activations, however, it's probably better to have it as a separate layer type (e.g. like batch norm).
@mickypaganini thanks a ton for the proposal, and your WIP branch. I reviewed the overall (new) design, which uses hooks and is focused around nn.Module rather than tensors themselves. I think this design is less invasive and makes a lot more sense, because pruning is primarily targeted at weights anyway. It looks great. I have some bikeshedding points. I think you'd want to place the code in …

```python
# Current
import torch.nn.utils as utils
utils.ln_structured_pruning(...)
utils.remove_pruning(...)

# New
import torch.nn.utils.prune as prune
prune.ln_structured(...)
prune.remove(...)
```

For each doc snippet, you probably also want to add …

@michaelklachko I disagree that quantization should be combined with pruning because:
@mickypaganini one thing to keep in mind, though, is that using forward hooks has its limitations. Example:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Randomly prune 50% of the slices of conv1.weight along dim 1.
        self.conv1 = P.random_structured(nn.Conv2d(1, 20, 5, 1), 'weight', 0.5, 1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        z = F.max_pool2d(y, 2, 2)
        return F.log_softmax(z, dim=1)
```

Here, you cannot prune based on the values of …
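The limitation comes from what a hook can see, which this pure-Python sketch of the hook-based reparametrization illustrates. As we understand the merged `torch.nn.utils.prune`, it keeps `<name>_orig` and `<name>_mask` and recomputes the effective parameter in a forward pre-hook. `TinyLinear` and `apply_mask_hook` below are hypothetical stand-ins for that mechanism, not the real classes. A pre-hook receives only the module and its input. Anything computed inside `forward`, like `y` in the example above, does not exist yet when the hook runs.

```python
class TinyLinear:
    """Hypothetical stand-in for an nn.Module supporting forward pre-hooks."""

    def __init__(self, weight):
        self.weight_orig = list(weight)       # original values, never zeroed
        self.weight_mask = [1] * len(weight)  # pruning mask
        self.weight = list(weight)            # effective weight used by forward
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def forward(self, x):
        # Single output unit: dot product of weight and input features.
        return sum(w * xi for w, xi in zip(self.weight, x))

    def __call__(self, x):
        for hook in self._pre_hooks:
            hook(self, x)  # the hook sees the module and its *input* only
        return self.forward(x)


def apply_mask_hook(module, inputs):
    # Recompute the effective weight from the original weight and the mask.
    module.weight = [w * m for w, m in zip(module.weight_orig, module.weight_mask)]


layer = TinyLinear([0.5, -2.0, 0.1])
layer.register_forward_pre_hook(apply_mask_hook)
layer.weight_mask = [1, 1, 0]
y = layer([1.0, 1.0, 1.0])  # -1.5: the third weight is masked out
```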
Oh, I wish this were true! :) #18318 refers only to post-training quantization, and only deals with 8-bit precision. That's a good start, but it's trivial compared to all the active research on binary and ternary networks. As you go below 4 bits you start seeing accuracy degradation, inversely proportional to the model size. New papers get posted every week on ways to reduce that degradation; in the last six months alone I've seen ~10 papers claiming a new state of the art :) Note that the latest Nvidia cards offer native support for 1-bit ops. There are far more ways to perform quantization than to perform pruning :)

As I'm typing this, I'm training a convnet for a mixed-signal chip we built, where the weights are analog but activations have to be stored in digital memory and therefore must be quantized. Note that in modern convnets activations consume a lot more memory than weights, so compressing them is a lot more effective for memory-constrained devices. So I'm currently exploring ways to binarize activations.

How is this relevant to this ticket? As a separate effort, I've also experimented with several activation pruning methods to reduce power consumption. I noticed that if I apply a temporal sparsity constraint on activations, it becomes easier to find the optimal clipping threshold for binarization. Basically, it's easier to binarize sparse activations, so pruning can be considered the first stage of the quantization process.

Having said that, I actually agree that pruning should be kept separate from quantization. This is not because one is more mature than the other (both are still very experimental), but because they are sufficiently distinct concepts, and it makes sense to apply them as separate steps (which especially helps during debugging). My point was more about treating them as similar types of operations. In fact, if someone is interested in doing quantization, they should also be looking at pruning methods, and vice versa. How you do one affects the effectiveness of doing the other. The ultimate goal is to compress the model.

TLDR:
@soumith thanks for the positive feedback! OK on …

On your second point, perhaps a naive solution:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)

    def forward(self, x, return_hidden=False):
        y = F.relu(self.conv1(x))
        out = F.log_softmax(F.max_pool2d(y, 2, 2), dim=1)
        # Optionally expose the hidden activation so that an external
        # activation-based pruning method can inspect it.
        if return_hidden:
            return y, out
        else:
            return out
```

Then, similar to other pruning methods provided so far, we'd have to implement an activation-based pruning method like … An alternative would be to have … Or we could even join the two and compute the activations from data … Option 1 might be the least error-prone, but also the least flexible.

If instead you were referring to some sort of real-time pruning with scope limited to each individual forward call, then that would require rewriting the …
By the way, even the last case of ephemeral activation pruning can already be supported by the pruning module as is. It will require interacting with the …

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as P


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        # Prototype class from the WIP branch; in the merged module the
        # equivalent is P.LnStructured(amount, n, dim=-1).
        self.pruning1 = P.LnStructuredPruningMethod(amount=0.5, n=2, axis=-1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        # Compute a fresh mask from the current activations on every call.
        mask = self.pruning1.compute_mask(y, default_mask=torch.ones_like(y))
        masked_y = y * mask.to(dtype=y.dtype)
        out = F.log_softmax(F.max_pool2d(masked_y, 2, 2), dim=1)
        return out
```

This uses no hooks.
I agree, the goal is to unify interfaces for quantization and pruning at the level of APIs. Great point about activations. The key difference is that there is a lot more hardware support for 8-bit (and now lower-bitwidth) quantization, allowing for faster kernel implementations that run on a wide range of devices. Sparsity support is still limited to hardware accelerators. There also seems to be a big distinction between sparsifying weights only (they can be packed and treated like a constant for inference) and sparsifying activations (dynamic, requires specialized hardware).
Just to clarify, there are three potential efficiency benefits from pruning (or quantization): reducing model size on disk, reducing memory consumption (during training or inference), and saving computation. Of these, memory consumption is arguably the biggest concern; specifically, memory consumption during training is the biggest concern for researchers (the PyTorch core user base). That's why methods like gradient checkpointing are popular when working with large models. Note that during training there's no difference between packing/unpacking weights and packing/unpacking activations. @mickypaganini
@michaelklachko: a quick solution would be to move the instantiation from the …
What do you mean? I'd probably do it like this, but I'm not sure if this is considered a good practice:
@mickypaganini, this is a very useful feature. As discussed, there are some issues we may need to consider.
For pruning existing models via local pruning, I don't think you need to reimplement them or modify their code, because you can simply loop over the modules and set the pruning hooks on them. If it's activation pruning or global pruning, then yes, you have to change the code of those models, and we shouldn't strive to find a solution where you don't touch the source code of those models.
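Local and global magnitude pruning differ in where the threshold is computed. Local pruning removes the same fraction from every layer, so a per-module loop is enough. Global pruning ranks all weights together, so layers can end up with very different sparsity. The pure-Python sketch below contrasts the two. The function name and the dict-of-lists layout are hypothetical; the merged module offers `global_unstructured` for the tensor case.

```python
def magnitude_masks(layers, amount, scope="local"):
    """Hypothetical helper: magnitude pruning masks for several layers.

    layers: dict mapping a parameter name to a flat list of weights.
    scope="local" prunes `amount` of each layer independently;
    scope="global" prunes `amount` of all weights ranked together.
    """
    if scope == "local":
        masks = {}
        for name, w in layers.items():
            k = round(amount * len(w))
            pruned = set(sorted(range(len(w)), key=lambda i: abs(w[i]))[:k])
            masks[name] = [0 if i in pruned else 1 for i in range(len(w))]
        return masks
    # Global: one ranking over every weight in every layer.
    flat = [(abs(v), name, i) for name, w in layers.items() for i, v in enumerate(w)]
    k = round(amount * len(flat))
    pruned = {(name, i) for _, name, i in sorted(flat)[:k]}
    return {name: [0 if (name, i) in pruned else 1 for i in range(len(w))]
            for name, w in layers.items()}


layers = {"fc1": [0.1, 0.2, 0.3, 0.4], "fc2": [1.0, 2.0, 3.0, 4.0]}
local = magnitude_masks(layers, 0.5, scope="local")
global_ = magnitude_masks(layers, 0.5, scope="global")
```

With the same `amount`, local pruning halves each layer. Global pruning empties `fc1` entirely because all of its weights are smaller than any weight in `fc2`.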
I have a basic question about pruning. After applying a fine-grained pruning method that masks out weights randomly, how will speedups be achieved for these sparse weights in PyTorch? Are there plans to release a SparseConv2d, SparseLinear, etc.?
Pruning is at a stage where we aren't looking at speedups and performance. It's still researchy.
Alright, thank you. I was wondering, then, what the use of sparse tensor support is. How can one leverage sparse tensors for pruning? I ask because they seem like a natural fit for representing pruned tensors.
PyTorch library for sparse NN training: https://github.com/TimDettmers/sparse_learning
Thank you for this. I looked at the code, and the author masks out the weights with a boolean mask matrix. That is fine; however, I am wondering whether PyTorch will support pruning using sparse tensors, that is, without computing the unnecessary multiplications with zeros.
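Whether a sparse layout saves memory depends on density. Here is a back-of-the-envelope sketch. It assumes float32 values (4 bytes) and a COO layout that stores one int64 index (8 bytes) per dimension for each nonzero, as PyTorch's COO tensors do. It ignores all other overhead. The helper names are hypothetical.

```python
def numel(shape):
    n = 1
    for s in shape:
        n *= s
    return n


def dense_bytes(shape, value_bytes=4):
    """Bytes for a dense tensor (float32 by default)."""
    return numel(shape) * value_bytes


def coo_bytes(shape, density, value_bytes=4, index_bytes=8):
    """Bytes for a COO tensor: each nonzero stores its value plus one
    int64 index per dimension (assumed layout, overhead ignored)."""
    nnz = round(density * numel(shape))
    return nnz * (value_bytes + index_bytes * len(shape))


def break_even_density(ndim, value_bytes=4, index_bytes=8):
    """Density below which COO uses less memory than dense."""
    return value_bytes / (value_bytes + index_bytes * ndim)


print(dense_bytes((1000, 1000)))      # 4000000
print(coo_bytes((1000, 1000), 0.1))   # 2000000
print(break_even_density(2))          # 0.2
```

Under these assumptions, a pruned 2-D weight only saves memory in COO form once more than 80% of its entries are zero. A 4-D conv weight needs about 89% sparsity, since the break-even density is 4/36.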
Hello,
To handle global structured pruning, it should be able to change the dimensions of some tensors and propagate the change to the rest of the graph. Do you think this should be a feature of this work, or is it out of scope?
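For a plain two-layer MLP, "propagating the change" means two edits. First, drop the pruned hidden units' rows from the first layer's weight and bias. Second, drop the matching input columns from the next layer's weight. Convolutions work the same way with output/input channels, and any batch norm in between must be trimmed too. Graphs with residual connections or concatenations are harder. Below is a pure-Python sketch with hypothetical names. It checks that removing a fully pruned ReLU unit leaves the output unchanged.

```python
def shrink_linear_pair(w1, b1, w2, keep):
    """Hypothetical helper: physically remove hidden units from a 2-layer MLP.

    w1: [hidden][in] first-layer weight, b1: [hidden] bias,
    w2: [out][hidden] second-layer weight, keep: hidden indices to keep.
    """
    new_w1 = [w1[h] for h in keep]                   # drop output rows of layer 1
    new_b1 = [b1[h] for h in keep]
    new_w2 = [[row[h] for h in keep] for row in w2]  # drop matching input columns
    return new_w1, new_b1, new_w2


def mlp(x, w1, b1, w2):
    """ReLU hidden layer followed by a linear output layer."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


# Hidden unit 1 is fully pruned: zero weight row and zero bias.
w1 = [[1.0, 0.0], [0.0, 0.0], [0.5, 0.5]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 9.0, 2.0]]
x = [2.0, 4.0]

small = shrink_linear_pair(w1, b1, w2, keep=[0, 2])
full_out = mlp(x, w1, b1, w2)
small_out = mlp(x, *small)
```

Both outputs are `[8.0]`. Because the pruned unit's ReLU output is always 0, its outgoing weight (9.0) never contributed.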
Summary: Provides implementation for feature request issue #20402. Adds pruning functionalities (structured and unstructured, local and global, as well as pruning from user-provided mask). Associated tutorial here: pytorch/tutorials#605 cc: soumith Pull Request resolved: #24076 Differential Revision: D18400431 Pulled By: mickypaganini fbshipit-source-id: a97bd6ca61f8600ae411da9ff6533c232aae1a51
@mickypaganini should this be closed as already implemented? Or is there more to do?
Yes, this can be closed now. This was implemented and merged. Upcoming changes and improvements to the pruning module are being tracked in other issues.
🚀 Tensor and nn.Module Pruning
Tensor method and/or `nn` util to sparsify tensors and/or models, according to various pruning techniques in the literature.

Motivation
State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. By contrast, biological neural networks are known to use efficient sparse connectivity. It's important to identify the best techniques to compress models by reducing the number of parameters in them. This would reduce memory, battery, and hardware consumption without sacrificing accuracy, make it possible to deploy lightweight models on device, and guarantee privacy through on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics of over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets" [1]), as a destructive neural architecture search technique, and more.
Goal of this feature: harmonizing pruning practices by providing a standard interface in PyTorch.
Target audience: researchers, engineering and product teams.
Pitch
Minimalist API, with deeper flexibility for power-users.
At the tensor level, this could look as follows:
A not-in-place `.prune` method will return a `torch.Tensor` of the same type and size as the one it acts on. In-place pruning is supported via `t.prune_(...)`.

At the model level, this will require a bit of thinking but should follow similar API patterns. This is important because not all pruning methods make sense on all parameters in a model (pruning conv kernels != pruning biases != pruning RNNs != pruning in the presence of batch norm, etc.).
First, we should have a sensible, well-documented default behavior for the average user's API, where a call to `net.prune(method='L1', amount=0.8)` defaults to pruning PyTorch "prepackaged" modules (such as linear, conv, and recurrent layers) in some sensible, expected way.

Most power users, though, would probably want to prune custom layers, or prune different layer types or layers at different depths using different pruning methods or pruning method parameters. This could be specified via a dictionary, which maps parameter names (contained in `net.state_dict().keys()`) to a pruning method and its parameters:

Similar to the tensor operations, model-level pruning could return a copy of the model, or act on the model in place.
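The dictionary-driven, model-level interface described above could look roughly like this. The sketch is pure Python; parameters are flat lists keyed by `state_dict`-style names. The config maps each name to a method name and its keyword arguments. The method registry (`l1_unstructured`, `threshold`) and `prune_by_config` are hypothetical illustrations, not a proposed PyTorch API. Parameters missing from the config are left untouched, and this version returns a pruned copy rather than acting in place.

```python
def l1_unstructured(w, amount):
    """Mask zeroing the `amount` fraction of smallest-magnitude weights."""
    k = round(amount * len(w))
    pruned = set(sorted(range(len(w)), key=lambda i: abs(w[i]))[:k])
    return [0 if i in pruned else 1 for i in range(len(w))]


def threshold(w, min_abs):
    """Hypothetical method: keep only weights with |w| >= min_abs."""
    return [1 if abs(v) >= min_abs else 0 for v in w]


METHODS = {"l1_unstructured": l1_unstructured, "threshold": threshold}


def prune_by_config(params, config):
    """Return a pruned copy of `params` (name -> flat list of weights).

    config: name -> (method_name, kwargs). Names absent from the config
    are copied unchanged.
    """
    pruned = {}
    for name, w in params.items():
        if name not in config:
            pruned[name] = list(w)
            continue
        method, kwargs = config[name]
        mask = METHODS[method](w, **kwargs)
        pruned[name] = [v * m for v, m in zip(w, mask)]
    return pruned


params = {
    "conv1.weight": [0.3, -0.01, 0.2, 0.05],
    "conv1.bias": [0.1, 0.2],
    "fc.weight": [0.5, -0.02, 0.9],
}
config = {
    "conv1.weight": ("l1_unstructured", {"amount": 0.5}),
    "fc.weight": ("threshold", {"min_abs": 0.1}),
}
new_params = prune_by_config(params, config)
```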
Pruning methods can be used during or after training. This implementation will be training-loop-agnostic: the user will have to write their own training loop to decide when to prune and what to do with the pruned object (re-initialize and retrain, finetune, etc.).
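Because the user owns the training loop, one common pattern is iterative pruning: prune a fraction `p` of the remaining weights, finetune, and repeat. The arithmetic behind such a schedule is simple. After `r` rounds the total sparsity is `1 - (1 - p)**r`. Inverting that gives the per-round fraction needed to hit a target sparsity in a fixed number of rounds. The helper names below are hypothetical.

```python
def cumulative_sparsity(p, rounds):
    """Total fraction pruned after `rounds` rounds, each pruning a fraction
    p of the weights that are still unpruned."""
    return 1.0 - (1.0 - p) ** rounds


def per_round_fraction(target, rounds):
    """Per-round fraction p that reaches `target` sparsity after `rounds`."""
    return 1.0 - (1.0 - target) ** (1.0 / rounds)


def schedule(p, rounds):
    """Cumulative sparsity after each round (to be interleaved with finetuning)."""
    return [cumulative_sparsity(p, r) for r in range(1, rounds + 1)]


# Pruning 20% of the remaining weights per round: 20%, 36%, 48.8%, ...
plan = schedule(0.2, 3)
```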
Alternatives
Depending on where this will live within the codebase, `t.prune(method=method, **kwargs)` could also look like `torch.nn.utils.pruning_function(t, **kwargs)` or `torch.nn.utils.pruning_function(module=module, name=param_name, **kwargs)`. I personally would prefer the first option because the kwargs are parameters of the pruning method itself, while `t` is the tensor it acts on (some pruning methods will also have to take in some data `X` or `(X, y)` when they're applied), but the last option is more in line with how, say, `weight_norm` is implemented. Perhaps, for Module-level application, following the example of the `weight_norm` implementation using hooks will make this more PyTorch-y, but I don't know if we want to sacrifice the ability to act directly on a tensor that is not part of a Module. Would that go into `torch.nn.functional`? Open to suggestions here.

cc @albanD @mruberry @jbschlosser