18
18
# limitations under the License.
19
19
20
20
import re
21
- from collections import UserDict , defaultdict
21
+ from collections import UserDict
22
22
23
23
import numpy as np
24
- import yaml
25
24
26
- from ...config import WeightPruningConfig as WeightPruningConf
27
-
28
- try :
29
- from neural_compressor .conf .config import Pruner
30
- from neural_compressor .conf .dotdict import DotDict
31
- from neural_compressor .utils import logger
32
-
33
- from ...conf .config import PrunerV2
34
- from ...conf .pythonic_config import WeightPruningConfig
35
- from ...utils .utility import LazyImport
36
-
37
- torch = LazyImport ("torch" )
38
- nn = LazyImport ("torch.nn" )
39
- F = LazyImport ("torch.nn.functional" )
40
- tf = LazyImport ("tensorflow" )
41
- except :
42
- import logging
43
-
44
- import tensorflow as tf
45
- import torch
46
- import torch .nn as nn
47
- import torch .nn .functional as F
48
-
49
- from .dot_dict import DotDict # #TODO
50
-
51
- logger = logging .getLogger (__name__ )
52
- from .schema_check import PrunerV2
53
-
54
- class WeightPruningConfig :
55
- """Similar to torch optimizer's interface."""
56
-
57
- def __init__ (
58
- self ,
59
- pruning_configs = [{}], ##empty dict will use global values
60
- target_sparsity = 0.9 ,
61
- pruning_type = "snip_momentum" ,
62
- pattern = "4x1" ,
63
- op_names = [],
64
- excluded_op_names = [],
65
- start_step = 0 ,
66
- end_step = 0 ,
67
- pruning_scope = "global" ,
68
- pruning_frequency = 1 ,
69
- min_sparsity_ratio_per_op = 0.0 ,
70
- max_sparsity_ratio_per_op = 0.98 ,
71
- sparsity_decay_type = "exp" ,
72
- pruning_op_types = ["Conv" , "Linear" ],
73
- ** kwargs ,
74
- ):
75
- """Init a WeightPruningConfig object."""
76
- self .pruning_configs = pruning_configs
77
- self ._weight_compression = DotDict (
78
- {
79
- "target_sparsity" : target_sparsity ,
80
- "pruning_type" : pruning_type ,
81
- "pattern" : pattern ,
82
- "op_names" : op_names ,
83
- "excluded_op_names" : excluded_op_names , ##global only
84
- "start_step" : start_step ,
85
- "end_step" : end_step ,
86
- "pruning_scope" : pruning_scope ,
87
- "pruning_frequency" : pruning_frequency ,
88
- "min_sparsity_ratio_per_op" : min_sparsity_ratio_per_op ,
89
- "max_sparsity_ratio_per_op" : max_sparsity_ratio_per_op ,
90
- "sparsity_decay_type" : sparsity_decay_type ,
91
- "pruning_op_types" : pruning_op_types ,
92
- }
93
- )
94
- self ._weight_compression .update (kwargs )
25
+ from neural_compressor .utils import logger
26
+ from neural_compressor .utils .utility import DotDict
95
27
96
- @property
97
- def weight_compression (self ):
98
- """Get weight_compression."""
99
- return self ._weight_compression
28
+ from ...config import WeightPruningConfig as WeightPruningConf
29
+ from ...utils .utility import LazyImport
100
30
101
- @ weight_compression . setter
102
- def weight_compression ( self , weight_compression ):
103
- """Set weight_compression."""
104
- self . _weight_compression = weight_compression
31
+ torch = LazyImport ( "torch" )
32
+ nn = LazyImport ( "torch.nn" )
33
+ F = LazyImport ( "torch.nn.functional" )
34
+ tf = LazyImport ( "tensorflow" )
105
35
106
36
107
37
def get_sparsity_ratio (pruners , model ):
@@ -423,14 +353,10 @@ def check_key_validity_prunerv2(template_config, usr_cfg_dict):
423
353
for obj in user_config :
424
354
if isinstance (obj , dict ):
425
355
check_key_validity_dict (template_config , obj )
426
- elif isinstance (obj , PrunerV2 ):
427
- check_key_validity_prunerv2 (template_config , obj )
428
356
429
357
# single pruner, weightconfig or yaml
430
358
elif isinstance (user_config , dict ):
431
359
check_key_validity_dict (template_config , user_config )
432
- elif isinstance (user_config , PrunerV2 ):
433
- check_key_validity_prunerv2 (template_config , user_config )
434
360
return
435
361
436
362
@@ -470,7 +396,7 @@ def process_and_check_config(val):
470
396
default_config .update (default_global_config )
471
397
default_config .update (default_local_config )
472
398
default_config .update (params_default_config )
473
- if isinstance (val , WeightPruningConfig ) or isinstance ( val , WeightPruningConf ):
399
+ if isinstance (val , WeightPruningConf ):
474
400
global_configs = val .weight_compression
475
401
pruning_configs = val .pruning_configs
476
402
check_key_validity (default_config , pruning_configs )
@@ -494,21 +420,7 @@ def process_config(config):
494
420
Returns:
495
421
A config dict object.
496
422
"""
497
- if isinstance (config , str ):
498
- try :
499
- with open (config , "r" ) as f :
500
- content = f .read ()
501
- val = yaml .safe_load (content )
502
- ##schema.validate(val)
503
- return process_and_check_config (val )
504
- except FileNotFoundError as f :
505
- logger .error ("{}." .format (f ))
506
- raise RuntimeError ("The yaml file is not exist. Please check the file name or path." )
507
- except Exception as e :
508
- logger .error ("{}." .format (e ))
509
- raise RuntimeError ("The yaml file format is not correct. Please refer to document." )
510
-
511
- if isinstance (config , WeightPruningConfig ) or isinstance (config , WeightPruningConf ):
423
+ if isinstance (config , WeightPruningConf ):
512
424
return process_and_check_config (config )
513
425
else :
514
426
assert False , f"not supported type { config } "
@@ -618,25 +530,6 @@ def parse_to_prune_tf(config, model):
618
530
return new_modules
619
531
620
532
621
- def generate_pruner_config (info ):
622
- """Generate pruner config object from prune information.
623
-
624
- Args:
625
- info: A dotdict that saves prune information.
626
-
627
- Returns:
628
- pruner: A pruner config object.
629
- """
630
- return Pruner (
631
- initial_sparsity = 0 ,
632
- method = info .method ,
633
- target_sparsity = info .target_sparsity ,
634
- start_epoch = info .start_step ,
635
- end_epoch = info .end_step ,
636
- update_frequency = info .pruning_frequency ,
637
- )
638
-
639
-
640
533
def get_layers (model ):
641
534
"""Get each layer's name and its module.
642
535
0 commit comments