diff --git a/paddlenlp/transformers/bart/converter.py b/paddlenlp/transformers/bart/converter.py deleted file mode 100644 index f921236b7d31..000000000000 --- a/paddlenlp/transformers/bart/converter.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from collections import OrderedDict - -import numpy as np -import paddle -import torch -from transformers import BartForConditionalGeneration as hf_BartForConditionalGeneration - -from paddlenlp.transformers import ( - BartForConditionalGeneration as pp_BartForConditionalGeneration, -) -from paddlenlp.utils import load_torch -from paddlenlp.utils.downloader import get_path_from_url_with_filelock -from paddlenlp.utils.log import logger - -# Download huggingface models -hf_hub_repo = "fnlp/bart-base-chinese" -base_url = f"https://huggingface.co/{hf_hub_repo}/resolve/main/" - -pp_hf_checkpoint = hf_hub_repo.replace("/", "_") -os.makedirs(pp_hf_checkpoint, exist_ok=True) - -for i in [ - "config.json", - "vocab.txt", - "tokenizer_config.json", - "special_tokens_map.json", - "pytorch_model.bin", - "added_tokens.json", - "spiece.model", -]: - try: - get_path_from_url_with_filelock(f"{base_url}{i}", pp_hf_checkpoint) - except RuntimeError: - logger.warning(f"{base_url}{i} not found.") - -use_torch = False -try: - hf_model = load_torch(os.path.join(pp_hf_checkpoint, "pytorch_model.bin")) -except ValueError: - # Some models coming from pytorch_lighting - use_torch = True - hf_model = torch.load(os.path.join(pp_hf_checkpoint, "pytorch_model.bin"), map_location="cpu") - -huggingface_to_paddle_encoder = { - "model.encoder.embed_tokens": "bart.encoder.embed_tokens", - "model.encoder.embed_positions": "bart.encoder.encoder_embed_positions", - "model.encoder.layernorm_embedding": "bart.encoder.encoder_layernorm_embedding", - ".self_attn_layer_norm.": ".norm1.", - ".fc1.": ".linear1.", - ".fc2.": ".linear2.", - ".final_layer_norm.": ".norm2.", - "model.encoder": "bart.encoder.encoder", -} - -huggingface_to_paddle_decoder = { - "model.decoder.embed_tokens": "bart.decoder.embed_tokens", - "model.decoder.embed_positions": "bart.decoder.decoder_embed_positions", - "model.decoder.layernorm_embedding": "bart.decoder.decoder_layernorm_embedding", - ".self_attn_layer_norm.": ".norm1.", - ".encoder_attn.": ".cross_attn.", - ".encoder_attn_layer_norm.": ".norm2.", - ".fc1.": ".linear1.", - ".fc2.": ".linear2.", - ".final_layer_norm.": ".norm3.", - "model.decoder": "bart.decoder.decoder", -} - -skip_weights = [] - -dont_transpose = [ - ".embed_positions.weight", - ".embed_tokens.weight", - "layernorm_embedding.weight", - "norm.weight", - ".shared.weight", - "lm_head.weight", -] - -paddle_state_dict = OrderedDict() - -# Convert parameters -for k, v in hf_model.items(): - transpose = False - if k in skip_weights: - continue - if k[-7:] == ".weight": - if not any([w in k for w in dont_transpose]): - if v.ndim == 2: - v = v.transpose(0, 1) if use_torch else v.transpose() 
- transpose = True - oldk = k - - if "model.encoder." in k: - for huggingface_name, paddle_name in huggingface_to_paddle_encoder.items(): - k = k.replace(huggingface_name, paddle_name) - elif "model.decoder." in k: - for huggingface_name, paddle_name in huggingface_to_paddle_decoder.items(): - k = k.replace(huggingface_name, paddle_name) - - if oldk == "model.shared.weight": - k = "bart.shared.weight" - - if oldk == "lm_head.weight": - k = "lm_head_weight" - - logger.info(f"Converting: {oldk} => {k} | is_transpose {transpose}") - - paddle_state_dict[k] = v.data.numpy() if use_torch else v - -# Save to .pdparams -paddle.save(paddle_state_dict, os.path.join(pp_hf_checkpoint, "model_state.pdparams")) - -# Compare ppnlp with hf -paddle.set_grad_enabled(False) -torch.set_grad_enabled(False) -pp_model = pp_BartForConditionalGeneration.from_pretrained(pp_hf_checkpoint) -pp_model.eval() -hf_model = hf_BartForConditionalGeneration.from_pretrained(pp_hf_checkpoint) -hf_model.eval() - -input_ids = np.random.randint(1, 10000, size=(2, 10)) -pp_inputs = paddle.to_tensor(input_ids) -hf_inputs = torch.tensor(input_ids) - -pp_output = pp_model(pp_inputs) -hf_output = hf_model(hf_inputs) - -diff = abs(hf_output.logits.detach().numpy() - pp_output.numpy()) -logger.info(f"max diff: {np.max(diff)}, min diff: {np.min(diff)}") diff --git a/paddlenlp/transformers/bart/modeling.py b/paddlenlp/transformers/bart/modeling.py index dac80d79866d..f216ea7d7216 100644 --- a/paddlenlp/transformers/bart/modeling.py +++ b/paddlenlp/transformers/bart/modeling.py @@ -23,6 +23,7 @@ from paddle import Tensor from paddle.nn import Embedding, Layer, MultiHeadAttention +from ...utils.converter import StateDictNameMapping from ...utils.env import CONFIG_NAME from ...utils.log import logger from .. 
import PretrainedModel, register_base_model @@ -82,6 +83,264 @@ class BartPretrainedModel(PretrainedModel): base_model_prefix = "bart" config_class = BartConfig + @classmethod + def _get_name_mappings(cls, config: BartConfig) -> List[StateDictNameMapping]: + model_mappings = [ + ["shared.weight", "shared.weight"], + ] + + num_encoder_layers = config.num_encoder_layers or 0 + num_decoder_layers = config.num_decoder_layers or 0 + + if num_encoder_layers: + encoder_mappings = [ + ["encoder.embed_positions.weight", "encoder.encoder_embed_positions.weight"], + ["encoder.layernorm_embedding.weight", "encoder.encoder_layernorm_embedding.weight"], + ["encoder.layernorm_embedding.bias", "encoder.encoder_layernorm_embedding.bias"], + ] + + model_mappings.extend(encoder_mappings) + + for layer_index in range(num_encoder_layers): + encoder_mappings = [ + [ + f"encoder.layers.{layer_index}.self_attn.k_proj.weight", + f"encoder.encoder.layers.{layer_index}.self_attn.k_proj.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.self_attn.k_proj.bias", + f"encoder.encoder.layers.{layer_index}.self_attn.k_proj.bias", + ], + [ + f"encoder.layers.{layer_index}.self_attn.v_proj.weight", + f"encoder.encoder.layers.{layer_index}.self_attn.v_proj.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.self_attn.v_proj.bias", + f"encoder.encoder.layers.{layer_index}.self_attn.v_proj.bias", + ], + [ + f"encoder.layers.{layer_index}.self_attn.q_proj.weight", + f"encoder.encoder.layers.{layer_index}.self_attn.q_proj.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.self_attn.q_proj.bias", + f"encoder.encoder.layers.{layer_index}.self_attn.q_proj.bias", + ], + [ + f"encoder.layers.{layer_index}.self_attn.out_proj.weight", + f"encoder.encoder.layers.{layer_index}.self_attn.out_proj.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.self_attn.out_proj.bias", + f"encoder.encoder.layers.{layer_index}.self_attn.out_proj.bias", + ], + [ + f"encoder.layers.{layer_index}.fc1.weight", + f"encoder.encoder.layers.{layer_index}.linear1.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.fc1.bias", + f"encoder.encoder.layers.{layer_index}.linear1.bias", + ], + [ + f"encoder.layers.{layer_index}.fc2.weight", + f"encoder.encoder.layers.{layer_index}.linear2.weight", + "transpose", + ], + [ + f"encoder.layers.{layer_index}.fc2.bias", + f"encoder.encoder.layers.{layer_index}.linear2.bias", + ], + [ + f"encoder.layers.{layer_index}.self_attn_layer_norm.weight", + f"encoder.encoder.layers.{layer_index}.norm1.weight", + ], + [ + f"encoder.layers.{layer_index}.self_attn_layer_norm.bias", + f"encoder.encoder.layers.{layer_index}.norm1.bias", + ], + [ + f"encoder.layers.{layer_index}.final_layer_norm.weight", + f"encoder.encoder.layers.{layer_index}.norm2.weight", + ], + [ + f"encoder.layers.{layer_index}.final_layer_norm.bias", + f"encoder.encoder.layers.{layer_index}.norm2.bias", + ], + ] + + model_mappings.extend(encoder_mappings) + + if num_decoder_layers: + decoder_mappings = [ + ["decoder.embed_positions.weight", "decoder.decoder_embed_positions.weight"], + ["decoder.layernorm_embedding.weight", "decoder.decoder_layernorm_embedding.weight"], + ["decoder.layernorm_embedding.bias", "decoder.decoder_layernorm_embedding.bias"], + ] + + model_mappings.extend(decoder_mappings) + + for layer_index in range(num_decoder_layers): + decoder_mappings = [ + [ + f"decoder.layers.{layer_index}.self_attn.k_proj.weight", + 
f"decoder.decoder.layers.{layer_index}.self_attn.k_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.self_attn.k_proj.bias", + f"decoder.decoder.layers.{layer_index}.self_attn.k_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.self_attn.v_proj.weight", + f"decoder.decoder.layers.{layer_index}.self_attn.v_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.self_attn.v_proj.bias", + f"decoder.decoder.layers.{layer_index}.self_attn.v_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.self_attn.q_proj.weight", + f"decoder.decoder.layers.{layer_index}.self_attn.q_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.self_attn.q_proj.bias", + f"decoder.decoder.layers.{layer_index}.self_attn.q_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.self_attn.out_proj.weight", + f"decoder.decoder.layers.{layer_index}.self_attn.out_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.self_attn.out_proj.bias", + f"decoder.decoder.layers.{layer_index}.self_attn.out_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.k_proj.weight", + f"decoder.decoder.layers.{layer_index}.cross_attn.k_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.k_proj.bias", + f"decoder.decoder.layers.{layer_index}.cross_attn.k_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.v_proj.weight", + f"decoder.decoder.layers.{layer_index}.cross_attn.v_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.v_proj.bias", + f"decoder.decoder.layers.{layer_index}.cross_attn.v_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.q_proj.weight", + f"decoder.decoder.layers.{layer_index}.cross_attn.q_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.q_proj.bias", + f"decoder.decoder.layers.{layer_index}.cross_attn.q_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.out_proj.weight", + f"decoder.decoder.layers.{layer_index}.cross_attn.out_proj.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn.out_proj.bias", + f"decoder.decoder.layers.{layer_index}.cross_attn.out_proj.bias", + ], + [ + f"decoder.layers.{layer_index}.fc1.weight", + f"decoder.decoder.layers.{layer_index}.linear1.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.fc1.bias", + f"decoder.decoder.layers.{layer_index}.linear1.bias", + ], + [ + f"decoder.layers.{layer_index}.fc2.weight", + f"decoder.decoder.layers.{layer_index}.linear2.weight", + "transpose", + ], + [ + f"decoder.layers.{layer_index}.fc2.bias", + f"decoder.decoder.layers.{layer_index}.linear2.bias", + ], + [ + f"decoder.layers.{layer_index}.self_attn_layer_norm.weight", + f"decoder.decoder.layers.{layer_index}.norm1.weight", + ], + [ + f"decoder.layers.{layer_index}.self_attn_layer_norm.bias", + f"decoder.decoder.layers.{layer_index}.norm1.bias", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn_layer_norm.weight", + f"decoder.decoder.layers.{layer_index}.norm2.weight", + ], + [ + f"decoder.layers.{layer_index}.encoder_attn_layer_norm.bias", + f"decoder.decoder.layers.{layer_index}.norm2.bias", + ], + [ + f"decoder.layers.{layer_index}.final_layer_norm.weight", + f"decoder.decoder.layers.{layer_index}.norm3.weight", + ], + [ + f"decoder.layers.{layer_index}.final_layer_norm.bias", + f"decoder.decoder.layers.{layer_index}.norm3.bias", + ], + ] + + model_mappings.extend(decoder_mappings) + + # base-model prefix "BartModel" + if 
"BartModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "bart." + mapping[1] + + if "BartForQuestionAnswering" in config.architectures: + model_mappings.extend( + [ + ["qa_outputs.weight", "classifier.weight", "transpose"], + ["qa_outputs.bias", "classifier.bias"], + ] + ) + + if "BartForSequenceClassification" in config.architectures: + model_mappings.extend( + [ + ["classification_head.dense.weight", "classifier.dense.weight", "transpose"], + ["classification_head.dense.bias", "classifier.dense.bias"], + ["classification_head.out_proj.weight", "classifier.out_proj.weight", "transpose"], + ["classification_head.out_proj.bias", "classifier.out_proj.bias"], + ] + ) + + if "BartForConditionalGeneration" in config.architectures: + model_mappings.extend( + [ + ["lm_head.weight", "lm_head_weight"], + ["final_logits_bias", "final_logits_bias"], + ] + ) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + def init_weights(self, layer): """Initialization hook""" if isinstance(layer, (nn.Linear, nn.Embedding)): diff --git a/tests/transformers/bart/test_modeling.py b/tests/transformers/bart/test_modeling.py index 179a0e4eb9b8..7a051cca8089 100644 --- a/tests/transformers/bart/test_modeling.py +++ b/tests/transformers/bart/test_modeling.py @@ -14,11 +14,12 @@ # limitations under the License. import random +import tempfile import unittest import numpy as np import paddle -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class from paddlenlp.transformers import ( BartConfig, @@ -33,7 +34,7 @@ PaddingStrategy, TruncationStrategy, ) -from tests.testing_utils import slow +from tests.testing_utils import require_package, slow from ..test_generation_utils import GenerationTesterMixin from ..test_modeling_common import ModelTesterMixin, ids_tensor @@ -889,3 +890,132 @@ def test_cnn_summarization_same_as_fairseq(self): tok.batch_decode( hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) # assigned generated_summaries but never used + + +class BartModelCompatibilityTest(unittest.TestCase): + model_id = "hf-internal-testing/tiny-random-BartModel" + + @require_package("transformers", "torch") + def test_bart_converter(self): + with tempfile.TemporaryDirectory() as tempdir: + # 1. create input + input_ids = np.random.randint(100, 200, [1, 20]) + + # 2. forward the paddle model + from paddlenlp.transformers import BartModel + + paddle_model = BartModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir) + paddle_model.eval() + paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0] + + # 3. forward the torch model + import torch + from transformers import BartModel + + torch_model = BartModel.from_pretrained(self.model_id, cache_dir=tempdir) + torch_model.eval() + torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0] + + # 4. compare results + self.assertTrue( + np.allclose( + paddle_logit.detach().cpu().reshape([-1])[:9].numpy(), + torch_logit.detach().cpu().reshape([-1])[:9].numpy(), + rtol=1e-4, + ) + ) + + @require_package("transformers", "torch") + def test_bart_converter_from_local_dir_with_enable_torch(self): + with tempfile.TemporaryDirectory() as tempdir: + # 1. forward the torch model + from transformers import BartModel + + torch_model = BartModel.from_pretrained(self.model_id) + torch_model.save_pretrained(tempdir) + + # 2. 
load the paddle model with torch weight conversion disabled
+            from paddlenlp.transformers import BartModel, model_utils
+
+            model_utils.ENABLE_TORCH_CHECKPOINT = False
+
+            with self.assertRaises(ValueError) as error:
+                BartModel.from_pretrained(tempdir)
+            self.assertIn("conversion is been disabled", str(error.exception))
+            model_utils.ENABLE_TORCH_CHECKPOINT = True
+
+    @require_package("transformers", "torch")
+    def test_bart_converter_from_local_dir(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. create common input
+            input_ids = np.random.randint(100, 200, [1, 20])
+
+            # 2. forward the torch model
+            import torch
+            from transformers import BartModel
+
+            torch_model = BartModel.from_pretrained(self.model_id)
+            torch_model.eval()
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]
+
+            # 3. forward the paddle model
+            from paddlenlp.transformers import BartModel
+
+            paddle_model = BartModel.from_pretrained(tempdir)
+            paddle_model.eval()
+            paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]
+
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    rtol=1e-4,
+                )
+            )
+
+    @parameterized.expand(
+        [
+            ("BartModel",),
+            ("BartForSequenceClassification",),
+            ("BartForQuestionAnswering",),
+            ("BartForConditionalGeneration",),
+        ]
+    )
+    @require_package("transformers", "torch")
+    def test_bart_classes_from_local_dir(self, class_name, pytorch_class_name=None):
+        pytorch_class_name = pytorch_class_name or class_name
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. create common input
+            input_ids = np.random.randint(100, 200, [1, 20])
+            # wrap `input_ids` with bos/eos, because `transformers.BartForSequenceClassification` needs an `eos_mask`
+            input_ids = [[0] + input_ids[0].tolist() + [2]]
+
+            # 2. forward the torch model
+            import torch
+            import transformers
+
+            torch_model_class = getattr(transformers, pytorch_class_name)
+            torch_model = torch_model_class.from_pretrained(self.model_id)
+            torch_model.eval()
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]
+
+            # 3. forward the paddle model
+            from paddlenlp import transformers
+
+            paddle_model_class = getattr(transformers, class_name)
+            paddle_model = paddle_model_class.from_pretrained(tempdir)
+            paddle_model.eval()
+
+            paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]
+
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    atol=1e-3,
+                )
+            )
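
Note on the conversion approach: the deleted converter.py performed the HuggingFace-to-Paddle conversion imperatively, renaming keys by string replacement and transposing every 2-D weight not on a skip list, while `_get_name_mappings` now declares the same information as `[source_key, target_key, optional "transpose"]` entries wrapped in `StateDictNameMapping`. The sketch below illustrates how such a table can drive a state-dict conversion; it is a minimal, hypothetical stand-in (numpy only, with an invented `apply_name_mappings` helper), not the actual `paddlenlp.utils.converter` machinery. The transpose action exists because torch `nn.Linear` stores weights as `[out_features, in_features]` while paddle `nn.Linear` stores `[in_features, out_features]`.

import numpy as np

def apply_name_mappings(state_dict, mappings):
    """Rename keys per the mapping table; transpose 2-D weights when requested.

    Each mapping is [source_key, target_key] or [source_key, target_key, "transpose"],
    mirroring the shape of the entries built in _get_name_mappings above.
    """
    converted = {}
    for mapping in mappings:
        source_key, target_key = mapping[0], mapping[1]
        action = mapping[2] if len(mapping) > 2 else None
        if source_key not in state_dict:
            continue  # tolerate missing keys, e.g. optional heads
        value = np.asarray(state_dict[source_key])
        if action == "transpose" and value.ndim == 2:
            value = value.T  # torch [out, in] -> paddle [in, out]
        converted[target_key] = value
    return converted

# Example with one encoder feed-forward projection from the tables above.
mappings = [
    ["encoder.layers.0.fc1.weight", "encoder.encoder.layers.0.linear1.weight", "transpose"],
    ["encoder.layers.0.fc1.bias", "encoder.encoder.layers.0.linear1.bias"],
]
hf_like = {
    "encoder.layers.0.fc1.weight": np.zeros((3072, 768)),  # torch layout: [out, in]
    "encoder.layers.0.fc1.bias": np.zeros((3072,)),
}
pp_like = apply_name_mappings(hf_like, mappings)
assert pp_like["encoder.encoder.layers.0.linear1.weight"].shape == (768, 3072)  # paddle layout: [in, out]

Keeping the mappings declarative also explains the prefix pass in `_get_name_mappings`: when the checkpoint architecture is not the bare `BartModel`, every source key gains the `model.` prefix and every target key the `bart.` prefix, instead of duplicating the whole table per task head.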