Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Megatron-11B #10301

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
],
"models": [],
# Models
"models.megatron": ["MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronConfig", "MegatronTokenizer"],
"models.wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", "Wav2Vec2Tokenizer"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
Expand Down Expand Up @@ -279,6 +280,7 @@
# tokenziers-backed objects
if is_tokenizers_available():
# Fast tokenizers
_import_structure["models.megatron"].append("MegatronTokenizerFast")
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
Expand Down Expand Up @@ -367,6 +369,15 @@
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure

_import_structure["models.megatron"].extend(
[
"MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegatronForCausalLM",
"MegatronForSequenceClassification",
"MegatronModel",
]
)

_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1322,6 +1333,7 @@
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.marian import MarianConfig
from .models.mbart import MBartConfig
from .models.megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig, MegatronTokenizer
from .models.mmbt import MMBTConfig
from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer
from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer
Expand Down Expand Up @@ -1435,6 +1447,7 @@
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
from .models.megatron import MegatronTokenizerFast
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast
Expand Down Expand Up @@ -1733,6 +1746,12 @@
MBartForSequenceClassification,
MBartModel,
)
from .models.megatron import (
MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronForCausalLM,
MegatronForSequenceClassification,
MegatronModel,
)
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
lxmert,
marian,
mbart,
megatron,
mmbt,
mobilebert,
mpnet,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..megatron.configuration_megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from ..mt5.configuration_mt5 import MT5Config
Expand Down Expand Up @@ -74,6 +75,7 @@
(key, value)
for pretrained_map in [
# Add archive maps here
MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand Down Expand Up @@ -118,6 +120,7 @@
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("megatron", MegatronConfig),
("wav2vec2", Wav2Vec2Config),
("convbert", ConvBertConfig),
("led", LEDConfig),
Expand Down Expand Up @@ -168,6 +171,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("megatron", "Megatron"),
("wav2vec2", "Wav2Vec2"),
("convbert", "ConvBERT"),
("led", "LED"),
Expand Down
16 changes: 12 additions & 4 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings
from ...utils import logging

# Add modeling imports here
from ..albert.modeling_albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,
Expand Down Expand Up @@ -66,8 +64,6 @@
CamembertForTokenClassification,
CamembertModel,
)

# Add modeling imports here
from ..convbert.modeling_convbert import (
ConvBertForMaskedLM,
ConvBertForMultipleChoice,
Expand Down Expand Up @@ -158,6 +154,14 @@
MBartForSequenceClassification,
MBartModel,
)

# Add modeling imports here
# Add modeling imports here
from ..megatron.modeling_megatron import (
MegatronForCausalLM,
MegatronForSequenceClassification,
MegatronModel,
)
from ..mobilebert.modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
Expand Down Expand Up @@ -276,6 +280,7 @@
LxmertConfig,
MarianConfig,
MBartConfig,
MegatronConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
Expand Down Expand Up @@ -304,6 +309,7 @@
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(MegatronConfig, MegatronModel),
(Wav2Vec2Config, Wav2Vec2Model),
(ConvBertConfig, ConvBertModel),
(LEDConfig, LEDModel),
Expand Down Expand Up @@ -424,6 +430,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Causal LM mapping
(MegatronConfig, MegatronForCausalLM),
(CamembertConfig, CamembertForCausalLM),
(XLMRobertaConfig, XLMRobertaForCausalLM),
(RobertaConfig, RobertaForCausalLM),
Expand Down Expand Up @@ -501,6 +508,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(MegatronConfig, MegatronForSequenceClassification),
(ConvBertConfig, ConvBertForSequenceClassification),
(LEDConfig, LEDForSequenceClassification),
(DistilBertConfig, DistilBertForSequenceClassification),
Expand Down
74 changes: 74 additions & 0 deletions src/transformers/models/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
"configuration_megatron": ["MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronConfig"],
"tokenization_megatron": ["MegatronTokenizer"],
}

if is_tokenizers_available():
_import_structure["tokenization_megatron_fast"] = ["MegatronTokenizerFast"]

if is_torch_available():
_import_structure["modeling_megatron"] = [
"MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegatronForCausalLM",
"MegatronForSequenceClassification",
"MegatronModel",
"MegatronPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_megatron import MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronConfig
from .tokenization_megatron import MegatronTokenizer

if is_tokenizers_available():
from .tokenization_megatron_fast import MegatronTokenizerFast

if is_torch_available():
from .modeling_megatron import (
MEGATRON_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronForCausalLM,
MegatronForSequenceClassification,
MegatronModel,
MegatronPreTrainedModel,
)


else:
import importlib
import os
import sys

class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""

__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]

def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)

sys.modules[__name__] = _LazyModule(__name__, _import_structure)
150 changes: 150 additions & 0 deletions src/transformers/models/megatron/configuration_megatron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Megatron model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

MEGATRON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"megatron-11b": "https://huggingface.co/anton-l/megatron-11b/resolve/main/config.json",
# See all Megatron models at https://huggingface.co/models?filter=megatron
}


class MegatronConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.MegatronModel`. It is used to
instantiate an Megatron model according to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to that of the Megatron `megatron-11b
<https://huggingface.co/megatron-11b>`__ architecture.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


Args:
vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the Megatron model. Defines the number of different tokens that can be represented by
the :obj:`inputs_ids` passed when calling :class:`~transformers.MegatronModel` or
:class:`~transformers.TFMegatronModel`.
d_model (:obj:`int`, `optional`, defaults to 1024):
Dimensionality of the layers and the pooler layer.
decoder_layers (:obj:`int`, `optional`, defaults to 12):
Number of decoder layers.
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <see
https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).

Example::

>>> from transformers import MegatronModel, MegatronConfig

>>> # Initializing a Megatron megatron-11b style configuration
>>> configuration = MegatronConfig()

>>> # Initializing a model from the megatron-11b style configuration
>>> model = MegatronModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "megatron"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=51200,
max_position_embeddings=1024,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=False,
is_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=0,
classifier_dropout=0.0,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
is_decoder=is_decoder,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)

self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

@property
def num_attention_heads(self) -> int:
return self.decoder_attention_heads

@property
def hidden_size(self) -> int:
return self.d_model
Loading