T5 with past ONNX export #13014

Merged · 3 commits · Aug 6, 2021
Changes from all commits
34 changes: 28 additions & 6 deletions src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -15,7 +15,7 @@
 """ GPT Neo model configuration """
 
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional
 
 from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
@@ -253,8 +253,12 @@ def _number_key_values(self):
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
         if self.use_past:
-            for i in range(self._number_key_values):
-                common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
 
         common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
 
@@ -264,9 +268,12 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
         common_outputs = super().outputs
         if self.use_past:
-            for i in range(self._number_key_values):
-                common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
-
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
         return common_outputs
 
     def generate_dummy_inputs(
@@ -315,3 +322,18 @@ def generate_dummy_inputs
         )
 
         return ordered_inputs
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                if len(t) == 1:
+                    flatten_output[f"{name}.{idx}.key_value"] = t[0]
+                else:
+                    flatten_output[f"{name}.{idx}.key"] = t[0]
+                    flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
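A note for readers following along: here is a minimal standalone sketch of the naming scheme this file now implements (the helper below is hypothetical, not the transformers class itself). Local-attention layers cache one fused tensor and therefore get a single `key_value` entry, while global layers get separate `key`/`value` entries:

    from collections import OrderedDict

    def gpt_neo_past_axes(attention_layers, prefix="past_key_values"):
        # Dynamic-axes mapping analogous to what the config's `inputs` property builds.
        axes = OrderedDict()
        for i, kind in enumerate(attention_layers):
            if kind == "local":
                # Local attention caches one fused (batch, sequence, hidden) tensor.
                axes[f"{prefix}.{i}.key_value"] = {0: "batch", 1: "sequence"}
            else:
                # Global attention caches key and value of shape (batch, heads, sequence, head_dim).
                axes[f"{prefix}.{i}.key"] = {0: "batch", 2: "sequence"}
                axes[f"{prefix}.{i}.value"] = {0: "batch", 2: "sequence"}
        return axes

    # A four-layer model alternating global and local attention:
    print(gpt_neo_past_axes(["global", "local", "global", "local"]))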
82 changes: 60 additions & 22 deletions src/transformers/models/t5/configuration_t5.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 """ T5 model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional
 
 from transformers import PreTrainedTokenizer, TensorType
 
+from ... import is_torch_available
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -140,9 +141,6 @@ def num_hidden_layers(self):
 
 
 class T5OnnxConfig(OnnxConfigWithPast):
-    def __init__(self, config: PretrainedConfig, use_past: bool = False):
-        super().__init__(config, use_past)
-
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict(
@@ -155,29 +153,30 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
         )
 
         if self.use_past:
-            for i in range(self._config.num_layers):
-                common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
+            for i in range(0, self._config.num_layers):
+                common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
 
         return common_inputs
 
     @property
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_outputs = OrderedDict(
-            [
-                ("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
-                ("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
-            ]
-        )
+        common_outputs = super().outputs
+
+        if "last_hidden_state" in common_outputs:
+            common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
 
         if self.use_past:
             for i in range(self._config.num_layers):
-                common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
+                common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
+                common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
 
+        if self.task == "default":
+            common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
+
         return common_outputs
 
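For a concrete picture of what the two rewritten properties emit, here is a short illustrative sketch for a hypothetical two-layer config (the real code reads `num_layers` from the model config). The `decoder` entries are the decoder self-attention cache and the `encoder` entries the cross-attention cache, four tensors per layer:

    entries = {}
    for i in range(2):  # stand-in for self._config.num_layers
        for branch in ("decoder", "encoder"):
            for kv in ("key", "value"):
                entries[f"past_key_values.{i}.{branch}.{kv}"] = {0: "batch", 2: "past_sequence"}

    assert len(entries) == 2 * 4  # four cache tensors per layer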
@@ -189,8 +188,6 @@ def generate_dummy_inputs(
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        if self.use_past:
-            raise NotImplementedError()
 
         # Generate encoder inputs
         encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
@@ -199,4 +196,45 @@
         decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
         decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
 
-        return dict(**encoder_inputs, **decoder_inputs)
+        ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+                batch = encoder_inputs["input_ids"].shape[0]
+                encoder_seq_length = encoder_inputs["input_ids"].shape[1]
+                encoder_shape = (
+                    batch,
+                    self._config.num_heads,
+                    encoder_seq_length,
+                    self._config.hidden_size // self._config.num_heads,
+                )
+                decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
+
+                ordered_inputs["past_key_values"] = []
+                for _ in range(self._config.num_layers):
+                    ordered_inputs["past_key_values"].append(
+                        (
+                            torch.zeros(decoder_shape),
+                            torch.zeros(decoder_shape),
+                            torch.zeros(encoder_shape),
+                            torch.zeros(encoder_shape),
+                        )
+                    )
+
+        return ordered_inputs
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
+                flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
+                flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
+                flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
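As a sanity check on the dummy cache shapes generated above, a hedged sketch with illustrative numbers (batch and encoder length are arbitrary here; num_heads=8 and hidden_size=512 match t5-small, giving 64-dimensional heads):

    import torch

    batch, num_heads, encoder_seq_length, hidden_size = 2, 8, 7, 512
    head_dim = hidden_size // num_heads  # 64

    encoder_shape = (batch, num_heads, encoder_seq_length, head_dim)  # cross-attention cache
    decoder_shape = (batch, num_heads, 1, head_dim)  # decoding starts from a single position

    # One 4-tuple per layer: (self-attn key, self-attn value, cross-attn key, cross-attn value)
    past_key_values = [
        (
            torch.zeros(decoder_shape),
            torch.zeros(decoder_shape),
            torch.zeros(encoder_shape),
            torch.zeros(encoder_shape),
        )
        for _ in range(6)  # t5-small has 6 layers
    ]
    assert past_key_values[0][2].shape == (2, 8, 7, 64)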
6 changes: 2 additions & 4 deletions src/transformers/models/t5/modeling_t5.py
@@ -429,8 +429,6 @@ def forward(
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
         batch_size, seq_length = hidden_states.shape[:2]
 
-        int_seq_length = int(seq_length)
-
         real_seq_length = seq_length
 
         if past_key_value is not None:
@@ -499,7 +497,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -int_seq_length:, :]
+                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
@@ -629,7 +627,7 @@ def forward(
             if len(past_key_value) != expected_num_past_key_values:
                 raise ValueError(
                     f"There should be {expected_num_past_key_values} past states. "
-                    f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
+                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                     f"Got {len(past_key_value)} past key / value states"
                 )
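On the position-bias change: reading the query length from the tensor (`hidden_states.size(1)`) rather than from a Python int captured earlier should keep the slice dynamic when the model is traced for export. A small sketch of the slicing with illustrative shapes (not the module itself):

    import torch

    batch, n_heads, past_len, q_len = 1, 8, 5, 1
    key_len = past_len + q_len

    # Bias over all (query, key) pairs, as compute_bias would produce while decoding.
    position_bias = torch.randn(batch, n_heads, key_len, key_len)

    # With cached past states only the rows for the current query positions matter;
    # q_len plays the role of hidden_states.size(1) in the real code.
    sliced = position_bias[:, :, -q_len:, :]
    assert sliced.shape == (batch, n_heads, q_len, key_len)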
33 changes: 32 additions & 1 deletion src/transformers/onnx/config.py
@@ -14,7 +14,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, List, Mapping, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
 
 from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
 
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
     _TASKS_TO_COMMON_OUTPUTS = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
         "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
         "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
@@ -228,6 +229,24 @@ def restore_ops(self):
             orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
             setattr(spec.o, spec.name, orig_op)
 
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        """
+        Flatten any potential nested structure expanding the name of the field with the index of the element within the
+        structure.
+
+        Args:
+            name: The name of the nested structure
+            field: The structure to, potentially, be flattened
+
+        Returns:
+            (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
+
+        """
+        from itertools import chain
+
+        return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
+
 
 class OnnxConfigWithPast(OnnxConfig, ABC):
     def __init__(
@@ -285,3 +304,15 @@ def generate_dummy_inputs(
         # Generate dummy inputs according to compute batch and sequence
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
         return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
+
+    @staticmethod
+    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
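To see the two flattening behaviors side by side, a standalone sketch (free functions mirroring the methods above, so it runs on its own):

    from itertools import chain

    def flatten_base(name, field):
        # Base OnnxConfig behavior: chain the nested iterables and number the items.
        return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}

    def flatten_with_past(name, field):
        # OnnxConfigWithPast behavior: each element is one layer's (key, value) pair.
        flat = {}
        for idx, (key, value) in enumerate(field):
            flat[f"{name}.{idx}.key"] = key
            flat[f"{name}.{idx}.value"] = value
        return flat

    print(flatten_base("past_key", [[0], [1], [2]]))
    # {'past_key.0': 0, 'past_key.1': 1, 'past_key.2': 2}
    print(flatten_with_past("present", [("k0", "v0"), ("k1", "v1")]))
    # {'present.0.key': 'k0', 'present.0.value': 'v0', 'present.1.key': 'k1', 'present.1.value': 'v1'}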
5 changes: 2 additions & 3 deletions src/transformers/onnx/convert.py
@@ -24,7 +24,6 @@
 from ..file_utils import is_torch_onnx_dict_inputs_support_available
 from ..utils import logging
 from .config import OnnxConfig
-from .utils import flatten_output_collection_property
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -163,7 +162,7 @@ def validate_model_outputs(
         if name == "past_key_values":
             name = "present"
         if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
             ref_outputs_dict.update(value)
         else:
             ref_outputs_dict[name] = value
@@ -172,7 +171,7 @@
     onnx_inputs = {}
     for name, value in reference_model_inputs.items():
         if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
             onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
         else:
             onnx_inputs[name] = value.numpy()
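The point of routing through `config` is that validation now defers to whichever flattening scheme the model's ONNX config defines (plain key/value pairs, GPT-Neo's fused key_value, or T5's decoder/encoder 4-tuples). A simplified sketch of the reference-output side of that loop, with `config` standing in for any OnnxConfig subclass:

    def build_ref_outputs_dict(config, ref_outputs):
        # `ref_outputs` maps output names to tensors or nested tuples (e.g. past_key_values).
        ref_outputs_dict = {}
        for name, value in ref_outputs.items():
            if name == "past_key_values":
                # The exported graph names the returned cache "present".
                name = "present"
            if isinstance(value, (list, tuple)):
                ref_outputs_dict.update(config.flatten_output_collection_property(name, value))
            else:
                ref_outputs_dict[name] = value
        return ref_outputs_dict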
6 changes: 5 additions & 1 deletion src/transformers/onnx/features.py
@@ -21,6 +21,7 @@
     AutoModelForCausalLM,
     AutoModelForMultipleChoice,
     AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
 )
@@ -46,6 +47,7 @@ class FeaturesManager:
     _TASKS_TO_AUTOMODELS = {
         "default": AutoModel,
         "causal-lm": AutoModelForCausalLM,
+        "seq2seq-lm": AutoModelForSeq2SeqLM,
         "sequence-classification": AutoModelForSequenceClassification,
         "token-classification": AutoModelForTokenClassification,
         "multiple-choice": AutoModelForMultipleChoice,
@@ -61,7 +63,9 @@ class FeaturesManager:
         "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
         "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
         "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
-        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
+        "t5": supported_features_mapping(
+            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+        ),
         "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
         "gpt-neo": supported_features_mapping(
             "default",
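The registry maps each model type to the features it can be exported with. A hedged sketch of that shape (the factory below is a stand-in, not the real supported_features_mapping helper, which may also attach task information):

    def supported_features_mapping(*features, onnx_config_cls):
        # Stand-in: map each feature name to the config class that handles it.
        return {feature: onnx_config_cls for feature in features}

    class T5OnnxConfig:  # placeholder for the real class
        pass

    supported = {
        "t5": supported_features_mapping(
            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
        ),
    }
    assert "seq2seq-lm-with-past" in supported["t5"]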
19 changes: 0 additions & 19 deletions src/transformers/onnx/utils.py
@@ -14,7 +14,6 @@
 
 from ctypes import c_float, sizeof
 from enum import Enum
-from typing import Any, Dict, Iterable
 
 
 class ParameterFormat(Enum):
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat)
         Size (in byte) taken to save all the parameters
     """
     return num_parameters * dtype.size
-
-
-def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
-    """
-    Flatten any potential nested structure expanding the name of the field with the index of the element within the
-    structure.
-
-    Args:
-        name: The name of the nested structure
-        field: The structure to, potentially, be flattened
-
-    Returns:
-        (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
-
-    """
-    from itertools import chain
-
-    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
8 changes: 2 additions & 6 deletions tests/test_onnx_v2.py
@@ -34,11 +34,7 @@
     validate_model_outputs,
 )
 from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
-from transformers.onnx.utils import (
-    compute_effective_axis_dimension,
-    compute_serialized_parameters_size,
-    flatten_output_collection_property,
-)
+from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
 from transformers.testing_utils import require_onnx, require_torch, slow
 
 
@@ -95,7 +91,7 @@ def test_flatten_output_collection_property(self):
         ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
         """
         self.assertEqual(
-            flatten_output_collection_property("past_key", [[0], [1], [2]]),
+            OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
             {
                 "past_key.0": 0,
                 "past_key.1": 1,