hf bart training and inference #316

Merged: 16 commits, Jun 2, 2022
410 changes: 410 additions & 0 deletions examples/training/huggingface/bart/summarization/ls_bart_model.py

Large diffs are not rendered by default.


@@ -0,0 +1,36 @@
# Copyright 2021 The LightSeq Team
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

THIS_DIR=$(dirname $(readlink -f $0))

export TASK_NAME=summarization

python3 -m torch.distributed.launch \
--nproc_per_node=1 \
$THIS_DIR/run_summarization.py \
--model_name_or_path facebook/bart-base \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--output_dir /tmp/$TASK_NAME \
--max_source_length 128 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--overwrite_output_dir \
--seed 1234 \
--logging_steps 10 \
--fp16 \
--predict_with_generate
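
# For reference, a minimal inference sketch (not part of this diff) showing how the
# fine-tuned summarization checkpoint could be used. It assumes the run above saved a
# Hugging Face-format BART checkpoint under /tmp/summarization (the --output_dir used
# in the script); the path and generation settings are illustrative.
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("/tmp/summarization")
model = BartForConditionalGeneration.from_pretrained("/tmp/summarization").eval()

article = "LightSeq provides fused CUDA kernels for Transformer training and inference ..."
inputs = tokenizer(article, max_length=128, truncation=True, return_tensors="pt")
# beam-search generation; num_beams/max_length are example values
summary_ids = model.generate(**inputs, num_beams=4, max_length=64)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])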
115 changes: 115 additions & 0 deletions examples/training/huggingface/bert/ls_hf_transformer_layer.py
@@ -1,5 +1,17 @@
import torch.nn as nn
from lightseq.training.ops.pytorch.quantization import qat_mode, disable_quant
from lightseq.training.ops.pytorch.torch_transformer_layers import BertEmbeddingLayer
from transformers import (
BertForSequenceClassification,
BertPreTrainedModel,
BertLayer,
BertLMHeadModel,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForMultipleChoice,
BertForTokenClassification,
BertForQuestionAnswering,
)


def get_hf_bert_enc_layer_params(layer):
@@ -114,3 +126,106 @@ def gen_bert_enc_config(training_args, config):
model.bert.encoder.layer[i].apply(qat_mode)
else:
model.bert.encoder.layer[i].apply(disable_quant)


def hf_state_dict(model):
"""
Args:
        model: Hugging Face model whose encoder layers have been replaced with LightSeq layers
    Returns:
        Dict: the model's state dict in the original Hugging Face format
"""

def unwrap_model(model):
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model

def inject_hf_layer(config, hf_layer, ls_layer):
for layer_id in range(config.num_hidden_layers):
weight, bias = ls_layer[layer_id].params_dict()
layer = hf_layer[layer_id]
layer.attention.self.query.weight.data.copy_(weight["self_attn_q_proj"])
layer.attention.self.query.bias.data.copy_(bias["self_attn_q_proj"])
layer.attention.self.key.weight.data.copy_(weight["self_attn_k_proj"])
layer.attention.self.key.bias.data.copy_(bias["self_attn_k_proj"])
layer.attention.self.value.weight.data.copy_(weight["self_attn_v_proj"])
layer.attention.self.value.bias.data.copy_(bias["self_attn_v_proj"])
layer.attention.output.dense.weight.data.copy_(weight["self_attn_out_proj"])
layer.attention.output.dense.bias.data.copy_(bias["self_attn_out_proj"])
layer.attention.output.LayerNorm.weight.data.copy_(
weight["self_attn_layer_norm"]
)
layer.attention.output.LayerNorm.bias.data.copy_(
bias["self_attn_layer_norm"]
)
layer.intermediate.dense.weight.data.copy_(weight["fc1"])
layer.intermediate.dense.bias.data.copy_(bias["fc1"])
layer.output.dense.weight.data.copy_(weight["fc2"])
layer.output.dense.bias.data.copy_(bias["fc2"])
layer.output.LayerNorm.weight.data.copy_(weight["final_layer_norm"])
layer.output.LayerNorm.bias.data.copy_(bias["final_layer_norm"])

model_to_save = unwrap_model(model)
if not isinstance(model_to_save, LSBertPreTrainedModel):
raise ValueError("Must be ligtseq replaced model")
    # rebuild the original Hugging Face BertLayer modules and copy the LightSeq weights into them
ls_encoder_layer = model_to_save.bert.encoder.layer
model_to_save.bert.encoder.layer = nn.ModuleList(
[BertLayer(model.config) for _ in range(model.config.num_hidden_layers)]
)
inject_hf_layer(
model_to_save.config, model_to_save.bert.encoder.layer, ls_encoder_layer
)
state_dict = model_to_save.state_dict()
    # restore the LightSeq modules
model_to_save.bert.encoder.layer = ls_encoder_layer
return state_dict


class LSBertPreTrainedModel(BertPreTrainedModel):
@classmethod
    def from_pretrained(cls, *args, training_args, model_args, **kwargs):
        cls.config = kwargs["config"]
        model = super().from_pretrained(*args, **kwargs)
        if model_args.module_type == 1 or model_args.module_type == 2:
            inject_ls_layer(model, training_args, model_args, cls.config)
return model

def save_pretrained(self, *args, **kwargs):
kwargs["state_dict"] = hf_state_dict(self)
super().save_pretrained(*args, **kwargs)


class LSBertForSequenceClassification(
LSBertPreTrainedModel, BertForSequenceClassification
):
"""from BertForSequenceClassification"""


class LSBertLMHeadModel(LSBertPreTrainedModel, BertLMHeadModel):
"""from BertLMHeadModel"""


class LSBertForMaskedLM(LSBertPreTrainedModel, BertForMaskedLM):
"""from BertForMaskedLM"""


class LSBertForNextSentencePrediction(
LSBertPreTrainedModel, BertForNextSentencePrediction
):
"""from BertForNextSentencePrediction"""


class LSBertForMultipleChoice(LSBertPreTrainedModel, BertForMultipleChoice):
"""from BertForMultipleChoice"""


class LSBertForTokenClassification(LSBertPreTrainedModel, BertForTokenClassification):
"""from BertForTokenClassification"""


class LSBertForQuestionAnswering(LSBertPreTrainedModel, BertForQuestionAnswering):
"""from BertForQuestionAnswering"""
27 changes: 21 additions & 6 deletions examples/training/huggingface/bert/task_glue/run_glue.py
@@ -41,11 +41,13 @@
TrainingArguments,
default_data_collator,
set_seed,
BertForSequenceClassification,
BertLayer,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from ls_hf_transformer_layer import inject_ls_layer
from ls_hf_transformer_layer import inject_ls_layer, LSBertForSequenceClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
@@ -406,19 +408,32 @@ def main():
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSequenceClassification.from_pretrained(

# # Replace with lightseq encoder layers and save the lightseq model
# model = AutoModelForSequenceClassification.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# # Replace with LightSeq encoder layers.
# if model_args.module_type == 1 or model_args.module_type == 2:
# inject_ls_layer(model, training_args, model_args, config)

    # Replace with LightSeq encoder layers and save the model in Hugging Face format
model = LSBertForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
training_args=training_args,
model_args=model_args,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)

# Replace with LightSeq encoder layers.
if model_args.module_type == 1 or model_args.module_type == 2:
inject_ls_layer(model, training_args, model_args, config)

# Preprocessing the datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
2 changes: 1 addition & 1 deletion examples/training/huggingface/bert/task_glue/run_glue.sh
@@ -33,5 +33,5 @@ python3 -m torch.distributed.launch \
--fp16 \
--seed 1234 \
--logging_steps 10 \
--module_type 2 \
--module_type 1 \
--enable_quant false
25 changes: 19 additions & 6 deletions examples/training/huggingface/bert/task_ner/run_ner.py
@@ -44,7 +44,7 @@
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from ls_hf_transformer_layer import inject_ls_layer
from ls_hf_transformer_layer import inject_ls_layer, LSBertForTokenClassification


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
@@ -366,19 +366,32 @@ def get_label_list(labels):
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForTokenClassification.from_pretrained(
# # Replace with lightseq encoder layers and save the lightseq model
# model = AutoModelForTokenClassification.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )

# # Replace with LightSeq encoder layers.
# if model_args.module_type == 1 or model_args.module_type == 2:
# inject_ls_layer(model, training_args, model_args, config)

    # Replace with LightSeq encoder layers and save the model in Hugging Face format
model = LSBertForTokenClassification.from_pretrained(
model_args.model_name_or_path,
training_args=training_args,
model_args=model_args,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)

# Replace with LightSeq encoder layers.
if model_args.module_type == 1 or model_args.module_type == 2:
inject_ls_layer(model, training_args, model_args, config)

# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
24 changes: 18 additions & 6 deletions examples/training/huggingface/bert/task_qa/run_qa.py
@@ -46,7 +46,7 @@
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
from ls_hf_transformer_layer import inject_ls_layer
from ls_hf_transformer_layer import inject_ls_layer, LSBertForQuestionAnswering


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
@@ -373,19 +373,31 @@ def main():
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForQuestionAnswering.from_pretrained(
# # Replace with lightseq encoder layers and save the lightseq model
# model = AutoModelForQuestionAnswering.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# # Replace with LightSeq encoder layers.
# if model_args.module_type == 1 or model_args.module_type == 2:
# inject_ls_layer(model, training_args, model_args, config)

    # Replace with LightSeq encoder layers and save the model in Hugging Face format
model = LSBertForQuestionAnswering.from_pretrained(
model_args.model_name_or_path,
training_args=training_args,
model_args=model_args,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)

# Replace with LightSeq encoder layers.
if model_args.module_type == 1 or model_args.module_type == 2:
inject_ls_layer(model, training_args, model_args, config)

# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(