@@ -41,6 +41,13 @@ def _load_state(path):
     return state
 
 
+def _strip_postfix(path):
+    path, ext = os.path.splitext(path)
+    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
+        "Unknown postfix {} from weights".format(ext)
+    return path
+
+
 def load_params(exe, prog, path):
     """
     Load model from the given path.
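
The new `_strip_postfix` helper simply drops a recognized Paddle weight suffix, so callers can pass either a bare checkpoint prefix or a `.pdparams`/`.pdopt`/`.pdmodel` file. A quick standalone check of its behavior (the checkpoint path below is made up for illustration):

```python
import os

def _strip_postfix(path):
    # same helper as added above: strip a known Paddle weight suffix
    path, ext = os.path.splitext(path)
    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
        "Unknown postfix {} from weights".format(ext)
    return path

print(_strip_postfix('output/yolov3/model_final.pdparams'))  # output/yolov3/model_final
print(_strip_postfix('output/yolov3/model_final'))           # unchanged: empty extension is allowed
```
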
@@ -50,20 +57,33 @@ def load_params(exe, prog, path):
         path (string): URL string or local model path.
     """
 
-    if not os.path.exists(path):
+    path = _strip_postfix(path)
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exist.".format(path))
 
     logger.info('Loading parameters from {}...'.format(path))
 
-    def _if_exist(var):
-        param_exist = os.path.exists(os.path.join(path, var.name))
-        do_load = param_exist
-        if do_load:
-            logger.debug('load weight {}'.format(var.name))
-        return do_load
+    ignore_set = set()
+    state = _load_state(path)
 
-    fluid.io.load_vars(exe, path, prog, predicate=_if_exist)
+    # ignore parameters whose shape mismatches between the model
+    # and the pretrained weights.
+    all_var_shape = {}
+    for block in prog.blocks:
+        for param in block.all_parameters():
+            all_var_shape[param.name] = param.shape
+    ignore_set.update([
+        name for name, shape in all_var_shape.items()
+        if name in state and shape != state[name].shape
+    ])
+
+    if len(ignore_set) > 0:
+        for k in ignore_set:
+            if k in state:
+                logger.warning('variable {} not used'.format(k))
+                del state[k]
+    fluid.io.set_program_state(prog, state)
 
 
 def save(exe, prog, path):
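
After this change, `load_params` loads the whole state dict once via `_load_state`, drops any entry whose shape disagrees with the program's parameters, and applies the rest with `fluid.io.set_program_state`. The shape filter itself is framework-agnostic; a minimal numpy sketch of the same idea, with invented parameter names and shapes:

```python
import numpy as np

# hypothetical parameter shapes from the program, and a pretrained state dict
model_shapes = {'conv1_weights': (64, 3, 3, 3), 'fc_weights': (1024, 81)}
state = {
    'conv1_weights': np.zeros((64, 3, 3, 3)),
    'fc_weights': np.zeros((1024, 90)),  # class count differs from the model
}

# same test as in load_params: ignore shape-mismatched entries
ignore_set = {name for name, shape in model_shapes.items()
              if name in state and shape != state[name].shape}
for name in ignore_set:
    del state[name]

print(sorted(state))  # ['conv1_weights'] -- the mismatched fc_weights was dropped
```
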
@@ -83,6 +103,7 @@ def save(exe, prog, path):
 def load_and_fusebn(exe, prog, path):
     """
     Fuse params of batch norm to scale and bias.
+
     Args:
         exe (fluid.Executor): The fluid.Executor object.
         prog (fluid.Program): save weight from which Program object.
@@ -104,19 +125,12 @@ def load_and_fusebn(exe, prog, path):
     # x is any prefix
     mean_variances = set()
     bn_vars = []
-
-    state = None
-    if os.path.exists(path + '.pdparams'):
-        state = _load_state(path)
+    state = _load_state(path)
 
     def check_mean_and_bias(prefix):
         m = prefix + 'mean'
         v = prefix + 'variance'
-        if state:
-            return v in state and m in state
-        else:
-            return (os.path.exists(os.path.join(path, m)) and
-                    os.path.exists(os.path.join(path, v)))
+        return v in state and m in state
 
     has_mean_bias = True
 
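
With `state` now always loaded up front, `check_mean_and_bias` becomes a pair of dictionary lookups instead of filesystem checks. A toy illustration with an invented prefix and a fake state dict standing in for what `_load_state` returns:

```python
# fake state dict keyed by variable name
state = {'bn1_mean': 0, 'bn1_variance': 0, 'bn1_scale': 0, 'bn1_offset': 0}

def check_mean_and_bias(prefix):
    # both running statistics must be present for this prefix
    return prefix + 'variance' in state and prefix + 'mean' in state

print(check_mean_and_bias('bn1_'))  # True
print(check_mean_and_bias('bn2_'))  # False
```
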
@@ -156,16 +170,14 @@ def check_mean_and_bias(prefix):
             bn_vars.append(
                 [scale_name, bias_name, mean_name, variance_name])
 
-    if state:
-        fluid.io.set_program_state(prog, state)
-    else:
-        load_params(exe, prog, path)
     if not has_mean_bias:
+        fluid.io.set_program_state(prog, state)
         logger.warning(
             "There are no parameters of batch norm in model {}. "
             "Skip to fuse batch norm. And load parameters done.".format(path))
         return
 
+    fluid.load(prog, path, exe)
     eps = 1e-5
     for names in bn_vars:
         scale_name, bias_name, mean_name, var_name = names
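
When the running statistics are present, the parameters are loaded with `fluid.load` and the loop that follows (outside this hunk) folds each batch norm's mean and variance into its scale and bias. The standard fusion arithmetic, sketched with numpy and invented per-channel values, using the same eps as above:

```python
import numpy as np

eps = 1e-5

# invented per-channel batch norm parameters
scale = np.array([1.0, 0.5])
bias = np.array([0.0, 0.1])
mean = np.array([0.2, -0.3])
var = np.array([1.5, 0.9])

# y = scale * (x - mean) / sqrt(var + eps) + bias
# collapses to y = fused_scale * x + fused_bias
fused_scale = scale / np.sqrt(var + eps)
fused_bias = bias - mean * fused_scale

x = np.array([0.7, 1.2])
assert np.allclose(scale * (x - mean) / np.sqrt(var + eps) + bias,
                   fused_scale * x + fused_bias)
```
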