From 7a6baa71ee3e3744dda7e1758924d4c31d44d907 Mon Sep 17 00:00:00 2001
From: megemini
Date: Fri, 14 Apr 2023 17:34:19 +0800
Subject: [PATCH] [Add]Add DistilBert to AutoConverter (#5672)

---
 .../transformers/distilbert/configuration.py  |   9 +-
 paddlenlp/transformers/distilbert/modeling.py | 120 ++++++++++++++++
 .../transformers/distilbert/test_modeling.py  | 132 +++++++++++++++++-
 3 files changed, 259 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/distilbert/configuration.py b/paddlenlp/transformers/distilbert/configuration.py
index 55008b51e422..724af5911ca8 100644
--- a/paddlenlp/transformers/distilbert/configuration.py
+++ b/paddlenlp/transformers/distilbert/configuration.py
@@ -122,7 +122,14 @@ class DistilBertConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""
     model_type = "distilbert"
-    attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"}
+    attribute_map: Dict[str, str] = {
+        "dropout": "classifier_dropout",
+        "num_classes": "num_labels",
+        "n_layers": "num_hidden_layers",  # for `transformers`
+        "n_heads": "num_attention_heads",  # for `transformers`
+        "dim": "hidden_size",  # for `transformers`
+        "hidden_dim": "intermediate_size",  # for `transformers`
+    }
     pretrained_init_configuration = DISTILBERT_PRETRAINED_INIT_CONFIGURATION
 
     def __init__(
diff --git a/paddlenlp/transformers/distilbert/modeling.py b/paddlenlp/transformers/distilbert/modeling.py
index fe2723262d22..a3e1199bfa0d 100644
--- a/paddlenlp/transformers/distilbert/modeling.py
+++ b/paddlenlp/transformers/distilbert/modeling.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
 import paddle
 import paddle.nn as nn
 
 from paddlenlp.utils.env import CONFIG_NAME
 
+from ...utils.converter import StateDictNameMapping
 from .. import PretrainedModel, register_base_model
 from .configuration import (
     DISTILBERT_PRETRAINED_INIT_CONFIGURATION,
@@ -79,6 +82,123 @@ class DistilBertPretrainedModel(PretrainedModel):
     config_class = DistilBertConfig
     model_config_file = CONFIG_NAME
 
+    @classmethod
+    def _get_name_mappings(cls, config: DistilBertConfig) -> List[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
+            ["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
+            ["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
+            ["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [
+                    f"transformer.layer.{layer_index}.attention.q_lin.weight",
+                    f"encoder.layers.{layer_index}.self_attn.q_proj.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.q_lin.bias",
+                    f"encoder.layers.{layer_index}.self_attn.q_proj.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.k_lin.weight",
+                    f"encoder.layers.{layer_index}.self_attn.k_proj.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.k_lin.bias",
+                    f"encoder.layers.{layer_index}.self_attn.k_proj.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.v_lin.weight",
+                    f"encoder.layers.{layer_index}.self_attn.v_proj.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.v_lin.bias",
+                    f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.out_lin.weight",
+                    f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.attention.out_lin.bias",
+                    f"encoder.layers.{layer_index}.self_attn.out_proj.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.sa_layer_norm.weight",
+                    f"encoder.layers.{layer_index}.norm1.weight",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.sa_layer_norm.bias",
+                    f"encoder.layers.{layer_index}.norm1.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.output_layer_norm.weight",
+                    f"encoder.layers.{layer_index}.norm2.weight",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.output_layer_norm.bias",
+                    f"encoder.layers.{layer_index}.norm2.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.ffn.lin1.weight",
+                    f"encoder.layers.{layer_index}.linear1.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.ffn.lin1.bias",
+                    f"encoder.layers.{layer_index}.linear1.bias",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.ffn.lin2.weight",
+                    f"encoder.layers.{layer_index}.linear2.weight",
+                    "transpose",
+                ],
+                [
+                    f"transformer.layer.{layer_index}.ffn.lin2.bias",
+                    f"encoder.layers.{layer_index}.linear2.bias",
+                ],
+            ]
+            model_mappings.extend(layer_mappings)
+
+        # base-model prefix "DistilBertModel"
+        if "DistilBertModel" not in config.architectures:
+            for mapping in model_mappings:
+                mapping[0] = "distilbert." + mapping[0]
+                mapping[1] = "distilbert." + mapping[1]
+
+        # downstream mappings
+        if "DistilBertForSequenceClassification" in config.architectures:
+            model_mappings.extend(
+                [
+                    ["pre_classifier.weight", "pre_classifier.weight", "transpose"],
+                    ["pre_classifier.bias", "pre_classifier.bias"],
+                    ["classifier.weight", "classifier.weight", "transpose"],
+                    ["classifier.bias", "classifier.bias"],
+                ]
+            )
+
+        if "DistilBertForTokenClassification" in config.architectures:
+            model_mappings.extend(
+                [
+                    ["classifier.weight", "classifier.weight", "transpose"],
+                    ["classifier.bias", "classifier.bias"],
+                ]
+            )
+
+        if "DistilBertForQuestionAnswering" in config.architectures:
+            model_mappings.extend(
+                [["qa_outputs.weight", "classifier.weight", "transpose"], ["qa_outputs.bias", "classifier.bias"]]
+            )
+
+        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
+        return mappings
+
     def init_weights(self, layer):
         """Initialization hook"""
         if isinstance(layer, (nn.Linear, nn.Embedding)):
diff --git a/tests/transformers/distilbert/test_modeling.py b/tests/transformers/distilbert/test_modeling.py
index 366a7ec74025..e92d7744fcae 100644
--- a/tests/transformers/distilbert/test_modeling.py
+++ b/tests/transformers/distilbert/test_modeling.py
@@ -14,9 +14,12 @@
 # limitations under the License.
 from __future__ import annotations
 
+import tempfile
 import unittest
 
+import numpy as np
 import paddle
+from parameterized import parameterized
 
 from paddlenlp.transformers import (
     DistilBertForMaskedLM,
@@ -27,7 +30,7 @@
 )
 from paddlenlp.transformers.distilbert.configuration import DistilBertConfig
 
-from ...testing_utils import slow
+from ...testing_utils import require_package, slow
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_common import (
     ModelTesterMixin,
@@ -288,6 +291,133 @@ def test_params_compatibility_of_init_method(self):
         assert model.dropout.p == 0.3
 
 
+class DistilBertModelCompatibilityTest(unittest.TestCase):
+    model_id = "hf-internal-testing/tiny-random-DistilBertModel"
+
+    @require_package("transformers", "torch")
+    def test_distilBert_converter(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            # 1. create input
+            input_ids = np.random.randint(100, 200, [1, 20])
+
+            # 2. forward the paddle model
+            from paddlenlp.transformers import DistilBertModel
+
+            paddle_model = DistilBertModel.from_pretrained(self.model_id, from_hf_hub=True, cache_dir=tempdir)
+            paddle_model.eval()
+            paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]
+
+            # 3. forward the torch model
+            import torch
+            from transformers import DistilBertModel
+
+            torch_model = DistilBertModel.from_pretrained(self.model_id, cache_dir=tempdir)
+            torch_model.eval()
+            torch_logit = torch_model(torch.tensor(input_ids))[0]
+
+            # 4. compare results
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    rtol=1e-4,
+                )
+            )
+
+    @require_package("transformers", "torch")
+    def test_distilBert_converter_from_local_dir_with_enable_torch(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            # 1. forward the torch model
+            from transformers import DistilBertModel
+
+            torch_model = DistilBertModel.from_pretrained(self.model_id)
+            torch_model.save_pretrained(tempdir)
+
+            # 2. load with paddle while torch checkpoint conversion is disabled
+            from paddlenlp.transformers import DistilBertModel, model_utils
+
+            model_utils.ENABLE_TORCH_CHECKPOINT = False
+
+            with self.assertRaises(ValueError) as error:
+                DistilBertModel.from_pretrained(tempdir)
+            self.assertIn("conversion is been disabled", str(error.exception))
+            model_utils.ENABLE_TORCH_CHECKPOINT = True
+
+    @require_package("transformers", "torch")
+    def test_distilBert_converter_from_local_dir(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. create common input
+            input_ids = np.random.randint(100, 200, [1, 20])
+
+            # 2. forward the torch model
+            import torch
+            from transformers import DistilBertModel
+
+            torch_model = DistilBertModel.from_pretrained(self.model_id)
+            torch_model.eval()
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(torch.tensor(input_ids))[0]
+
+            # 3. forward the paddle model
+            from paddlenlp.transformers import DistilBertModel
+
+            paddle_model = DistilBertModel.from_pretrained(tempdir)
+            paddle_model.eval()
+            paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]
+
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    rtol=1e-4,
+                )
+            )
+
+    @parameterized.expand(
+        [
+            ("DistilBertModel",),
+            ("DistilBertForQuestionAnswering",),
+            ("DistilBertForSequenceClassification",),
+            ("DistilBertForTokenClassification",),
+        ]
+    )
+    @require_package("transformers", "torch")
+    def test_distilBert_classes_from_local_dir(self, class_name, pytorch_class_name=None):
+        pytorch_class_name = pytorch_class_name or class_name
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            # 1. create common input
+            input_ids = np.random.randint(100, 200, [1, 20])
+
+            # 2. forward the torch model
+            import torch
+            import transformers
+
+            torch_model_class = getattr(transformers, pytorch_class_name)
+            torch_model = torch_model_class.from_pretrained(self.model_id)
+            torch_model.eval()
+            torch_model.save_pretrained(tempdir)
+            torch_logit = torch_model(torch.tensor(input_ids))[0]
+
+            # 3. forward the paddle model
+            from paddlenlp import transformers
+
+            paddle_model_class = getattr(transformers, class_name)
+            paddle_model = paddle_model_class.from_pretrained(tempdir)
+            paddle_model.eval()
+
+            paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]
+
+            self.assertTrue(
+                np.allclose(
+                    paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+                    atol=1e-3,
+                )
+            )
+
+
 class DistilBertModelIntegrationTest(ModelTesterPretrainedMixin, unittest.TestCase):
     base_model_class = DistilBertModel
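
For reference, the conversion path this patch enables can be exercised end to end in the same way the new compatibility tests do: pointing `from_pretrained` at a Hugging Face checkpoint (or at a local directory that only contains torch weights) routes through the `_get_name_mappings` table above to rename, and where marked "transpose", transpose the torch state dict into Paddle layout. A minimal sketch, assuming the same `hf-internal-testing/tiny-random-DistilBertModel` checkpoint used by the tests is reachable:

    import numpy as np
    import paddle

    from paddlenlp.transformers import DistilBertModel

    # Loading from the HF Hub triggers the torch -> paddle weight conversion
    # defined by _get_name_mappings in this patch.
    model = DistilBertModel.from_pretrained(
        "hf-internal-testing/tiny-random-DistilBertModel", from_hf_hub=True
    )
    model.eval()

    # Forward a random batch, mirroring the test inputs.
    input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
    sequence_output = model(input_ids)[0]
    print(sequence_output.shape)  # [1, 20, hidden_size]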