Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compression API supports distributed training #3361

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions paddlenlp/trainer/trainer_compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,14 @@ def init_func(layer):
def _recover_auto_model_forward(self):

def init_func(layer):
if isinstance(layer, self.base_model_class):
if isinstance(
layer, self.base_model_class
if not isinstance(self, paddle.DataParallel) else
self._layers.base_model_class):
layer.forward = layer._ori_forward

for layer in self.children():
for layer in self._layers.children() if isinstance(
self, paddle.DataParallel) else self.children():
layer.apply(init_func)
return self

Expand Down Expand Up @@ -293,7 +297,10 @@ def evaluate(model, data_loader):
if self.custom_dynabert_evaluate is not None:
return self.custom_dynabert_evaluate(model, data_loader)
if isinstance(model, OFA):
class_name = model.model.__class__.__name__
if isinstance(model.model, paddle.DataParallel):
class_name = model.model._layers.__class__.__name__
else:
class_name = model.model.__class__.__name__
else:
class_name = model.__class__.__name__
if "SequenceClassification" in class_name:
Expand Down Expand Up @@ -488,9 +495,12 @@ def _dynabert_export(self, ofa_model):
ofa_model._add_teacher = False
ofa_model, ofa_model.model = _recover_transformer_func(
ofa_model), _recover_transformer_func(ofa_model.model)

ori_num_heads = ofa_model.model.base_model.encoder.layers[
0].self_attn.num_heads
if isinstance(ofa_model.model, paddle.DataParallel):
ori_num_heads = ofa_model.model._layers.base_model.encoder.layers[
0].self_attn.num_heads
else:
ori_num_heads = ofa_model.model.base_model.encoder.layers[
0].self_attn.num_heads
for width_mult in self.args.width_mult_list:
model_dir = os.path.join(self.args.output_dir,
"width_mult_" + str(round(width_mult, 2)))
Expand Down Expand Up @@ -521,8 +531,12 @@ def _dynabert_export(self, ofa_model):
net = paddle.jit.to_static(origin_model_new, input_spec=input_shape)
paddle.jit.save(net, pruned_infer_model_dir)
# Recover num_heads of ofa_model.model
for layer in ofa_model.model.base_model.encoder.layers:
layer.self_attn.num_heads = ori_num_heads
if isinstance(ofa_model.model, paddle.DataParallel):
for layer in ofa_model.model._layers.base_model.encoder.layers:
layer.self_attn.num_heads = ori_num_heads
else:
for layer in ofa_model.model.base_model.encoder.layers:
layer.self_attn.num_heads = ori_num_heads
logger.info("Pruned models have been exported.")


Expand Down