Description
🚀 Tensor and nn.Module Pruning
A Tensor method and/or nn utility to sparsify tensors and/or models according to various pruning techniques from the literature.
Motivation
State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. By contrast, biological neural networks are known to use efficient sparse connectivity. Identifying the best techniques to compress models by reducing their number of parameters is important: it reduces memory, battery, and hardware consumption without sacrificing accuracy, makes it possible to deploy lightweight models on device, and helps preserve privacy through on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics between over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets" [1]), as a destructive neural architecture search technique, and more.
Goal of this feature: harmonizing pruning practices by providing a standard interface in PyTorch.
Target audience: researchers, engineering and product teams.
Pitch
Minimalist API, with deeper flexibility for power users.
At the tensor level, this could look as follows:
t = torch.randn(3, 2, 4)
# e.g.: randomly mask 6 entries
pruned_t = t.prune(method='random', amount=6)
# e.g.: prune bottom 80% of entries by absolute value
pruned_t = t.prune(method='L1', amount=0.8)
# e.g.: prune 50% of channels along the last dimension by L1 norm
pruned_t = t.prune(method='L1Structured', amount=0.5)
# e.g.: prune 2 channels along the 0th dimension by L2 norm
pruned_t = t.prune(method='L2Structured', amount=2, axis=0)
# e.g.: prune 1 channel along the last dimension by L0 norm
pruned_t = t.prune(method='LnStructured', n=0, amount=1)
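Since t.prune does not exist today, the following is a minimal sketch of how the calls above could be emulated with existing tensor operations, written as a free function. It assumes that an integer amount is an absolute count and a float amount is a fraction, that "channels" are slices along axis (default: the last dimension), and that the structured norm is taken over all the other dimensions. The names prune, _MASKS, and the *_mask helpers are hypothetical and chosen only for illustration.

```python
import torch


def _num_to_prune(amount, total):
    # Assumption: an int `amount` is an absolute count, a float is a fraction in [0, 1].
    if isinstance(amount, int):
        if not 0 <= amount <= total:
            raise ValueError(f"amount={amount} is out of range for {total} items")
        return amount
    if not 0.0 <= amount <= 1.0:
        raise ValueError(f"fractional amount={amount} must lie in [0, 1]")
    return int(round(amount * total))


def random_unstructured_mask(t, amount):
    # Zero out `amount` entries chosen uniformly at random.
    k = _num_to_prune(amount, t.numel())
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.randperm(t.numel(), device=t.device)[:k]] = 0
    return mask.reshape(t.shape)


def l1_unstructured_mask(t, amount):
    # Zero out the `amount` entries with the smallest absolute value.
    k = _num_to_prune(amount, t.numel())
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
    return mask.reshape(t.shape)


def ln_structured_mask(t, amount, n=2, axis=-1):
    # Zero out whole slices ("channels") along `axis`, ranked by their L_n norm
    # computed over all the other dimensions.
    axis = axis % t.dim()
    num_channels = t.shape[axis]
    k = _num_to_prune(amount, num_channels)
    flat = t.movedim(axis, 0).reshape(num_channels, -1)
    norms = torch.linalg.vector_norm(flat, ord=n, dim=1)
    mask = torch.ones_like(norms)
    mask[torch.topk(norms, k, largest=False).indices] = 0
    # Reshape the per-channel mask so it broadcasts against t.
    shape = [1] * t.dim()
    shape[axis] = num_channels
    return mask.view(shape).expand_as(t)


_MASKS = {
    "random": random_unstructured_mask,
    "L1": l1_unstructured_mask,
    "L1Structured": lambda t, amount, axis=-1: ln_structured_mask(t, amount, 1, axis),
    "L2Structured": lambda t, amount, axis=-1: ln_structured_mask(t, amount, 2, axis),
    "LnStructured": ln_structured_mask,
}


def prune(t, method, amount, **kwargs):
    # Not in place: returns a new tensor with the same dtype and shape as t.
    return t * _MASKS[method](t, amount, **kwargs)


t = torch.randn(3, 2, 4)
pruned_t = prune(t, "L2Structured", amount=2, axis=0)  # 2 of the 3 slices along dim 0 zeroed
```

Computing a mask and multiplying by it, rather than deleting entries, is what keeps the output the same size as the input. The same masks could later be reused for in-place or module-level pruning.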
A not-in-place .prune method will return a torch.Tensor of the same type and size as the one it acts on. In-place pruning is supported via t.prune_(...).
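The difference between the two variants can be sketched with ordinary tensor ops. This assumes L1 unstructured pruning with a fractional amount, and prune / prune_ / l1_mask are hypothetical free-function stand-ins for the proposed methods. The in-place version runs under torch.no_grad() so it can also be applied to leaf tensors that require grad (such as parameters), which PyTorch otherwise refuses to modify in place.

```python
import torch


def l1_mask(t, amount):
    # Hypothetical helper: 0 for the `amount` fraction of smallest-|x| entries, 1 elsewhere.
    k = int(round(amount * t.numel()))
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
    return mask.reshape(t.shape)


def prune(t, amount):
    # Not in place: t is left untouched; a new tensor of the same dtype and shape is returned.
    return t * l1_mask(t, amount)


def prune_(t, amount):
    # In place: t's storage is overwritten and t itself is returned,
    # following PyTorch's trailing-underscore convention.
    with torch.no_grad():
        return t.mul_(l1_mask(t, amount))


t = torch.randn(3, 2, 4)
pruned_t = prune(t, 0.5)  # t is still dense
prune_(t, 0.5)            # t now holds the same 12 zeros as pruned_t
```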
At the model level, this will require more thought, but it should follow similar API patterns. This matters because not all pruning methods make sense for all parameters in a model (pruning conv kernels != pruning biases != pruning RNNs != pruning in the presence of batch norm, etc.).
First, we should have a sensible, well-documented default behavior for the average user's API, where a call to net.prune(method='L1', amount=0.8) defaults to pruning PyTorch "prepackaged" modules (such as linear, conv, and recurrent layers) in some sensible, expected way.
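As one possible reading of "sensible, expected", here is a minimal sketch that assumes the default is layer-wise L1 unstructured pruning of the weight of Linear and Conv layers, applied in place, with biases and all other modules left alone. Recurrent layers (whose weights are named weight_ih_l0, weight_hh_l0, etc.) are omitted for brevity. prune_model_default, l1_mask, and DEFAULT_PRUNABLE are hypothetical names, not an existing API.

```python
import torch
import torch.nn as nn

# Assumption: the "sensible default" is layer-wise L1 unstructured pruning of the
# `weight` of Linear and Conv layers; biases and all other modules are left alone.
DEFAULT_PRUNABLE = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)


def l1_mask(t, amount):
    # 0 for the `amount` fraction of smallest-|x| entries, 1 elsewhere.
    k = int(round(amount * t.numel()))
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
    return mask.reshape(t.shape)


@torch.no_grad()
def prune_model_default(net, amount):
    # Hypothetical stand-in for net.prune(method='L1', amount=...), applied in place.
    for module in net.modules():
        if isinstance(module, DEFAULT_PRUNABLE):
            module.weight.mul_(l1_mask(module.weight, amount))
    return net


net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 6 * 6, 10))
prune_model_default(net, amount=0.8)
```

Pruning each layer separately keeps every layer at the same sparsity; a global threshold across layers would be an equally valid default and is one of the choices such documentation would need to spell out.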
Most power users, though, would probably want to prune custom layers, or to prune different layer types or layers at different depths using different pruning methods or pruning-method parameters. This could be specified via a dictionary that maps parameter names (contained in net.state_dict().keys()) to a pruning method and its parameters:
{
'features.0.weight' : L1PruningMethod(amount=0.8),
...
}
Similar to the tensor operations, model-level pruning could return a copy of the model, or act on the model in place.
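A minimal sketch of how such a dictionary could drive model-level pruning, returning either a copy or the same model pruned in place. L1PruningMethod here is a hypothetical class shaped after the dictionary above, and prune_model and its inplace flag are assumptions. Because state_dict() keys also include buffers (e.g. batch norm running statistics), this sketch only accepts names that are parameters.

```python
import copy

import torch
import torch.nn as nn


class L1PruningMethod:
    # Hypothetical pruning-method object matching the dictionary above.
    def __init__(self, amount):
        self.amount = amount

    def compute_mask(self, t):
        # 0 for the `amount` fraction of smallest-|x| entries, 1 elsewhere.
        k = int(round(self.amount * t.numel()))
        mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
        mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
        return mask.reshape(t.shape)


@torch.no_grad()
def prune_model(net, config, inplace=False):
    # `config` maps parameter names (a subset of net.state_dict().keys())
    # to pruning-method objects.
    if not inplace:
        net = copy.deepcopy(net)  # return a pruned copy; the caller's model is untouched
    params = dict(net.named_parameters())
    for name, method in config.items():
        if name not in params:
            raise KeyError(f"{name!r} is not a parameter of this model")
        params[name].mul_(method.compute_mask(params[name]))
    return net


net = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 2))
config = {
    "0.weight": L1PruningMethod(amount=0.5),
    "2.weight": L1PruningMethod(amount=0.25),
}
pruned_net = prune_model(net, config)    # copy
prune_model(net, config, inplace=True)   # in place
```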
Pruning methods can be applied during or after training. This implementation will be agnostic to the training loop: users will write their own loop and decide when to prune and what to do with the pruned object (re-initialize and retrain, fine-tune, etc.).
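To illustrate the split of responsibilities, here is a minimal sketch of one user-written loop: train densely, prune once, then fine-tune while re-applying the mask after every optimizer step so pruned weights stay at zero. The toy regression data, the l1_mask helper, and the prune-then-fine-tune schedule are all illustrative assumptions, not part of the proposal.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def l1_mask(t, amount):
    # 0 for the `amount` fraction of smallest-|x| entries, 1 elsewhere.
    k = int(round(amount * t.numel()))
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
    return mask.reshape(t.shape)


model = nn.Linear(8, 1)
X, y = torch.randn(64, 8), torch.randn(64, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()


def train(steps, mask=None):
    # The user's own loop; pruning is not built into it.
    for _ in range(steps):
        optimizer.zero_grad()
        loss_fn(model(X), y).backward()
        optimizer.step()
        if mask is not None:
            with torch.no_grad():
                model.weight.mul_(mask)  # keep pruned entries at zero while fine-tuning


train(100)                                   # 1) train the dense model
mask = l1_mask(model.weight.detach(), 0.5)   # 2) the user decides when to prune
with torch.no_grad():
    model.weight.mul_(mask)
train(50, mask=mask)                         # 3) fine-tune the sparse model
```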
Alternatives
Depending on where this will live within the codebase, t.prune(method=method, **kwargs) could also look like torch.nn.utils.pruning_function(t, **kwargs) or torch.nn.utils.pruning_function(module=module, name=param_name, **kwargs). I personally would prefer the first option because the kwargs are parameters of the pruning method itself, while t is the tensor it acts on (some pruning methods will also have to take in some data X or (X, y) when they're applied), but the last option is more in line with how, say, weight_norm is implemented. Perhaps, for Module-level application, following the example of the weight_norm implementation using hooks will make this more PyTorch-y, but I don't know if we want to sacrifice the ability to act directly on a tensor that is not part of a Module. Would that go into torch.nn.functional? Open to suggestions here.
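One way to reconcile the two styles, sketched under assumptions: keep the mask computation as a plain tensor function, and add a weight_norm-style reparametrization on top for modules. In the sketch, the original parameter is re-registered as <name>_orig, the mask is stored as a <name>_mask buffer, and a forward pre-hook rebuilds <name> before every forward pass. apply_pruning, the _orig/_mask naming, and l1_mask are hypothetical; only the hook mechanism itself (register_forward_pre_hook, register_parameter, register_buffer) is existing nn.Module API.

```python
import torch
import torch.nn as nn


def l1_mask(t, amount):
    # Plain tensor function: still usable on tensors that are not part of a Module.
    k = int(round(amount * t.numel()))
    mask = torch.ones(t.numel(), dtype=t.dtype, device=t.device)
    mask[torch.topk(t.abs().flatten(), k, largest=False).indices] = 0
    return mask.reshape(t.shape)


def apply_pruning(module, name, amount):
    # Hypothetical weight_norm-style reparametrization:
    #   <name>_orig (Parameter) * <name>_mask (buffer) -> <name> (plain attribute)
    orig = getattr(module, name)
    del module._parameters[name]
    # Re-register the same Parameter object, so existing optimizers still track it.
    module.register_parameter(name + "_orig", orig)
    module.register_buffer(name + "_mask", l1_mask(orig.detach(), amount))

    def recompute(mod, inputs):
        # Forward pre-hook: rebuild the pruned tensor before every forward call,
        # so gradients reach <name>_orig only through unpruned entries.
        setattr(mod, name, getattr(mod, name + "_orig") * getattr(mod, name + "_mask"))

    recompute(module, None)  # make <name> available before the first forward
    return module.register_forward_pre_hook(recompute)


layer = nn.Linear(4, 3)
handle = apply_pruning(layer, "weight", amount=0.5)
out = layer(torch.randn(2, 4))
```

With this design the mask persists in the module's state_dict and pruning survives further training without re-masking in the user's loop, at the cost of the extra indirection that weight_norm users already accept.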