[Model Compression] Pruning Scheduler #4089
Conversation
class TaskResult:
    def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
It seems compact_model_masks will not be used at all, since it is empty when the model has been sped up, and it is the same as old_structure_masks otherwise. I am curious why not delete it here?
You are right, compact_model_masks is redundant now. In fact, it could be replaced by a boolean flag indicating whether the model has been sped up. But some cases cannot apply speedup, and I think the masks related to those cases should end up in the (currently empty) compact_model_masks in the future. It is a reserved interface for the residual masks after speedup; maybe users can use it to apply a customized speedup.
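For context, a trimmed sketch of how such a reserved field could sit in TaskResult; only the parameters visible in the hunk above are kept (the real signature continues past them), and the comment reflects the intent described in this thread:

```python
from typing import Dict
from torch import Tensor
from torch.nn import Module

class TaskResult:
    def __init__(self, task_id: int, compact_model: Module,
                 compact_model_masks: Dict[str, Dict[str, Tensor]]):
        self.task_id = task_id
        self.compact_model = compact_model
        # Reserved interface: empty if speedup fully applied, otherwise it
        # holds the residual masks for layers speedup could not handle, so
        # users can apply a customized speedup on top of them later.
        self.compact_model_masks = compact_model_masks
```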
@@ -27,8 +27,9 @@ def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, T
     metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
     prune_num = int(sparsity_rate * metric.numel())
     if prune_num == 0:
-        continue
-    threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
+        threshold = metric.min() - 1
+    else:
+        threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
Why do we handle it this way instead of the previous treatment?
Because even though the mask is all ones, maybe it is better to save it in masks.
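A minimal sketch of the difference between the two treatments, assuming the final mask is produced by comparing the metric against the threshold:

```python
import torch

metric = torch.tensor([0.3, 0.7, 0.5])

# Old treatment: when prune_num == 0, `continue` skipped the layer, so no
# mask was recorded for it at all.
# New treatment: a threshold strictly below metric.min() keeps every
# element, so an all-ones mask is generated and still saved in masks.
threshold = metric.min() - 1
mask = torch.gt(metric, threshold).float()
print(mask)  # tensor([1., 1., 1.])
```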
@@ -125,54 +128,3 @@ def show_pruned_weights(self, dim: int = 0):
         sum_idx.remove(dim)
         index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
         _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
-
-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
export is removed, why?
Because I think the old export_model() is useless now, since we return the masks directly in compress(). Maybe it is better to let users handle the export logic, or we can support a new export_model().
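A sketch of what user-side export could look like under that proposal; the file names are illustrative, and `pruner` stands for any pruner prepared as elsewhere in this PR:

```python
import torch

# compress() returns the compact model and the masks directly, so exporting
# is just saving both artifacts wherever the user wants them.
model, masks = pruner.compress()
torch.save(model.state_dict(), 'pruned_model.pth')  # illustrative path
torch.save(masks, 'masks.pth')                      # illustrative path
```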
After removing the exporting function, what's the new interface between the pruner and speedup?
Revert export_model, and remove the ONNX export.
# pruning model
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
compact_model, old_structure_masks = self.pruner.compress()
old_structure_masks -> pruner_generated_masks
done
self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)

# save origin data in {log_dir}/intermidiate_model
"intermidiate_model"?
intermidiate_data?
done
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
""" | ||
Compare origin model and compact model, return the sparsity of each group mentioned in config list. | ||
A group means all layer mentioned in one config. | ||
i.e., a linear named 'linear1' and its weight size is [100, 100] in origin model, but in compact model, |
i.e. -> e.g.
done
def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                     config_list: List[Dict]) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    The current model means the compact model applied the masks. The compact model is the origin model after speed up.
where is "current model"?
This function computes how much the origin model has been compressed in the current state. The current state means compact_model + compact_model_masks (i.e., compact_model_masks applied on compact_model).
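A rough sketch of that definition (not the PR's implementation; the mask layout {module_name: {'weight': Tensor, ...}} is assumed):

```python
import torch

def apply_masks(compact_model, compact_model_masks):
    # "current state" = compact_model with compact_model_masks applied on top
    for name, module in compact_model.named_modules():
        if name in compact_model_masks and 'weight' in compact_model_masks[name]:
            with torch.no_grad():
                module.weight.mul_(compact_model_masks[name]['weight'])
    return compact_model
```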
done
# NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
# If we support more kinds of masks, this place needs refactoring.
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
-    expand_mask['bias_mask'] = weight_mask.clone()
+    expand_mask['bias'] = weight_mask.clone()
return expand_mask

def _compress_mask(self, mask: Tensor) -> Tensor:
add docstring
done
Returns
-------
Tuple[List[Dict], List[Dict], List[Dict]]
    (current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity).
add more description
done
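A worked toy example of how the three returned values could relate for a single layer; the numbers are made up, the multiplicative composition of the remaining fractions is the point:

```python
# Origin weight is 100x100, speedup shrank it to 50x100, and the residual
# mask keeps 80% of the compact weights.
origin_numel = 100 * 100             # 10000
compact_numel = 50 * 100             # 5000
remained = int(0.8 * compact_numel)  # 4000 unmasked weights

compact2origin = 1 - compact_numel / origin_numel  # 0.5, from speedup alone
mask2compact = 1 - remained / compact_numel        # 0.2, from masks alone
current2origin = 1 - remained / origin_numel       # 0.6, combined

# The remaining fractions compose multiplicatively:
assert abs((1 - current2origin) - (1 - compact2origin) * (1 - mask2compact)) < 1e-9
```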
@J-shang, please add tests for the pruning scheduler, maybe in the next PR if you like.
self.state = {}

for ref in self.referenced_paths():
    self._reference_counter.setdefault(ref, 0)
Is the _reference_counter defined at the class level the same object as self._reference_counter here?
We directly set the reference count to zero here; what if ref is also used by other threads?
Yes, all the self._reference_counter references in different instances refer to the same dict _reference_counter defined on the class.
The second point is a good question; we may need some kind of lock if we support multi-threading. For now, we run pruning in a single thread, and I will add comments in this part to remind us that it needs refactoring if we want to support multi-threading.
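A standalone sketch (hypothetical class name) of the first point, i.e. why every instance sees the same counter: attribute lookup on self falls back to the class, so all instances share the single class-level dict unless an instance rebinds the name:

```python
class Checkpoint:
    _reference_counter = {}  # one dict shared by ALL instances

    def __init__(self, path: str):
        # reads through self resolve to the class attribute; mutating the
        # dict (setdefault, item assignment) does not rebind the name
        self._reference_counter.setdefault(path, 0)
        self._reference_counter[path] += 1

a = Checkpoint('model_0.pth')
b = Checkpoint('model_0.pth')
assert a._reference_counter is b._reference_counter
print(Checkpoint._reference_counter)  # {'model_0.pth': 2}
```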
task_id
    The unique id of task.
compact_model
    The unwrapped compact pytorch model after pruning. If the compact model has speed up process during pruning,
has been speeduped during the pruning process
thx, fix it
compact_model
    The unwrapped compact pytorch model after pruning. If the compact model has speed up process during pruning,
    it will have a smaller structure compare with the model before pruning.
    If the compact model do not speed up, it will have the same structure with the model before pruning.
Same here, do not speed up -> has not been speeduped?
done
Depends on PR #4074.