GPT-Neo ONNX export (#12911)
GPT-Neo ONNX export and task / feature refactoring

Authored-by: Michael Benayoun <michael@huggingface.co>
michaelbenayoun authored and LysandreJik committed Aug 9, 2021
1 parent 2c255a2 commit 94b7db9
Showing 9 changed files with 380 additions and 103 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/gpt_neo/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
}

if is_torch_available():
@@ -43,7 +43,7 @@


if TYPE_CHECKING:
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig

if is_torch_available():
from .modeling_gpt_neo import (
142 changes: 142 additions & 0 deletions src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -14,7 +14,12 @@
# limitations under the License.
""" GPT Neo model configuration """

from collections import OrderedDict
from typing import Any, Mapping, Optional

from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging


@@ -173,3 +178,140 @@ def num_attention_heads(self):
@property
def num_hidden_layers(self):
return self.num_layers


def custom_unfold(input, dimension, size, step):
"""Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
import torch

shape = input.size()
rank = len(shape)
sizedim = shape[dimension]

low_indices = torch.arange(0, sizedim, step)
min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
indices = torch.arange(size) + low_indices[:min_length][:, None]

s = [slice(None)] * rank
s[dimension] = indices
sliced = input[s]

perm = list(range(0, rank + 1))
perm.append(perm.pop(dimension + 1))

return sliced.permute(perm)
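Not part of the diff, but a quick sanity check of the replacement against the native op (the shapes below are arbitrary):

```python
import torch

# Assumption: any shape/size/step combination valid for torch.Tensor.unfold;
# (2, 12) unfolded along dim 1 with size 4, step 2 gives 5 windows of 4.
x = torch.arange(24, dtype=torch.float32).reshape(2, 12)
expected = x.unfold(1, 4, 2)  # native op, shape (2, 5, 4)
actual = custom_unfold(x, dimension=1, size=4, step=2)
assert torch.equal(expected, actual)
```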


def custom_get_block_length_and_num_blocks(seq_length, window_size):
"""
    Custom implementation of GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX, as the
    original implementation uses Python variables and control flow.
"""
import torch

candidates = torch.arange(1, window_size)
remainders = torch.remainder(seq_length, candidates)
divisor_indices = remainders == 0
divisors = candidates[divisor_indices]
largest_divisor = torch.max(divisors)
return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
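Also not from the diff: the helper returns the largest candidate strictly below `window_size` that evenly divides `seq_length`, together with the resulting block count. With illustrative values of `seq_length=256` and `window_size=256` (GPT-Neo's default local window), that is 128, i.e. two blocks:

```python
import torch

# Illustrative values only.
block_length, num_blocks = custom_get_block_length_and_num_blocks(
    seq_length=torch.tensor(256), window_size=256
)
assert block_length.item() == 128 and num_blocks.item() == 2
```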


class GPTNeoOnnxConfig(OnnxConfigWithPast):
def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
if is_torch_available():
import torch

from .modeling_gpt_neo import GPTNeoAttentionMixin

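            # While exporting, temporarily patch torch.Tensor.unfold and the
            # static block-length helper with the ONNX-friendly versions above.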
patching_specs = [
PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
PatchingSpec(
GPTNeoAttentionMixin,
name="_get_block_length_and_num_blocks",
custom_op=custom_get_block_length_and_num_blocks,
op_wrapper=staticmethod,
),
]

super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)

self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
self._key_values_dynamic_axis = []
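        # Local attention layers cache a single past tensor of shape
        # (batch, sequence, hidden), so "sequence" is axis 1; global layers
        # cache separate key and value tensors of shape
        # (batch, heads, sequence, head_dim), where "sequence" is axis 2.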
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
else:
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})

@property
def _number_key_values(self):
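        # Each global attention layer contributes a key and a value tensor,
        # each local layer a single tensor: 2 * num_layers minus one per local layer.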
return (self._config.num_layers * 2) - self._num_local_attention

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
for i in range(self._number_key_values):
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]

common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

return common_inputs

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
if self.use_past:
for i in range(self._number_key_values):
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]

return common_outputs

def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # We need to order the inputs in the way they appear in forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

batch = common_inputs["input_ids"].shape[0]
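        # Dummy past states for a single past token: global layers hold
        # per-head key/value states, local layers hold previous hidden states.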
past_shapes = {
"global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
"local": (batch, 1, self._config.hidden_size),
}

# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch

ordered_inputs["past_key_values"] = []
for i in range(self._config.num_layers):
attention_type = self._config.attention_layers[i]
if attention_type == "global":
ordered_inputs["past_key_values"].append(
(
torch.zeros(past_shapes[attention_type]),
torch.zeros(past_shapes[attention_type]),
)
)
else:
ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))

ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
)

return ordered_inputs
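A hedged usage sketch (not from the commit) tying the pieces together; the checkpoint name is an assumption, and any GPT-Neo checkpoint should behave the same:

```python
from transformers import AutoTokenizer, GPTNeoConfig, TensorType
from transformers.models.gpt_neo import GPTNeoOnnxConfig

checkpoint = "EleutherAI/gpt-neo-125M"  # assumption: any GPT-Neo checkpoint
config = GPTNeoConfig.from_pretrained(checkpoint)
onnx_config = GPTNeoOnnxConfig(config, task="default", use_past=True)

# How many past_key_values.* inputs appear depends on the local/global
# layout in config.attention_layers (see _number_key_values above).
print(list(onnx_config.inputs.keys()))

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
```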
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1121,7 +1121,7 @@ def forward(
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
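        # (Presumably torch.arange, unlike Python's built-in range, keeps this
        # indexing expressible in the traced graph during ONNX export.)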

loss = None
if labels is not None:
2 changes: 1 addition & 1 deletion src/transformers/onnx/__init__.py
@@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
from .convert import export, validate_model_outputs
from .utils import ParameterFormat, compute_serialized_parameters_size
91 changes: 6 additions & 85 deletions src/transformers/onnx/__main__.py
@@ -14,101 +14,22 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple

from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs


if is_torch_available():
from transformers import AutoModel, PreTrainedModel

FEATURES_TO_AUTOMODELS = {
"default": AutoModel,
}


# Set of model topologies we support, associated with the features supported by each topology and the factory
SUPPORTED_MODEL_KIND = {
"albert": {"default": AlbertOnnxConfig.default},
"bart": {"default": BartOnnxConfig.default},
"bert": {"default": BertOnnxConfig.default},
"distilbert": {"default": DistilBertOnnxConfig.default},
"gpt2": {"default": GPT2OnnxConfig.default},
"longformer": {"default": LongformerOnnxConfig.default},
"roberta": {"default": RobertaOnnxConfig},
"t5": {"default": T5OnnxConfig.default},
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}


def get_model_from_features(features: str, model: str):
"""
Attempt to retrieve a model from a model's name and the features to be enabled.
Args:
features: The features required
model: The name of the model to export
    Returns:
        The instantiated model.
    """
if features not in FEATURES_TO_AUTOMODELS:
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")

return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)


def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features
Args:
model: The model to export
        features: The name of the features to check if they are available
Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
"""
if model.config.model_type not in SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model.name}) is not supported yet. "
f"Only {SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)

# Look for the features
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
if features not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support features {features}. "
f"Supported values are: {list(model_features.keys())}"
)

return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
from .features import FeaturesManager


def main():
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
parser.add_argument(
"--features",
choices=["default"],
"--feature",
choices=list(FeaturesManager.AVAILABLE_FEATURES),
default="default",
help="Export the model with some additional features.",
help="Export the model with some additional feature.",
)
parser.add_argument(
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
@@ -127,8 +48,8 @@ def main():

# Allocate the model
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = get_model_from_features(args.features, args.model)
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)

# Ensure the requested opset is sufficient
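For reference, a hedged sketch of the refactored flow driven programmatically instead of through the CLI; the checkpoint name is an assumption, and the calls mirror the lines above:

```python
from transformers.models.auto import AutoTokenizer
from transformers.onnx.features import FeaturesManager

model_name = "EleutherAI/gpt-neo-125M"  # assumption: any supported checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FeaturesManager.get_model_from_feature("default", model_name)
model_kind, config_factory = FeaturesManager.check_supported_model_or_raise(
    model, feature="default"
)
onnx_config = config_factory(model.config)
```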
(Diff for the 4 remaining files not shown.)
