diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
index 1e9995dfc7b6..85b4f702d512 100644
--- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -15,7 +15,7 @@
 """ GPT Neo model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional
 
 from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
@@ -253,8 +253,12 @@ def _number_key_values(self):
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
         if self.use_past:
-            for i in range(self._number_key_values):
-                common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
 
             common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
 
@@ -264,9 +268,12 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
         common_outputs = super().outputs
        if self.use_past:
-            for i in range(self._number_key_values):
-                common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
-
+            for i in range(self._config.num_layers):
+                if self._config.attention_layers[i] == "local":
+                    common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
+                else:
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
 
         return common_outputs
 
@@ -315,3 +322,19 @@ def generate_dummy_inputs(
         )
 
         return ordered_inputs
+
+    @classmethod
+    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                if len(t) == 1:
+                    # Local attention layers cache a single fused key/value tensor.
+                    flatten_output[f"{name}.{idx}.key_value"] = t[0]
+                else:
+                    flatten_output[f"{name}.{idx}.key"] = t[0]
+                    flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
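
GPT-Neo alternates global and local attention layers, and (as the `len(t) == 1` branch above reflects) the local layers cache a single fused key/value tensor, so the flattened ONNX tensor names differ per layer. A minimal sketch of the naming scheme this config produces, using a made-up four-layer attention pattern rather than a real checkpoint's `attention_layers`:

```python
# Toy attention pattern; real checkpoints define this in config.attention_layers.
attention_layers = ["global", "local", "global", "local"]

onnx_inputs = {"input_ids": {0: "batch", 1: "sequence"}}
for i, kind in enumerate(attention_layers):
    if kind == "local":
        # Local layers cache one fused tensor, hence a single input per layer.
        onnx_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
    else:
        # Global layers cache key and value separately; sequence is on axis 2.
        onnx_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
        onnx_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}

print(list(onnx_inputs))
# ['input_ids', 'past_key_values.0.key', 'past_key_values.0.value',
#  'past_key_values.1.key_value', 'past_key_values.2.key', ...]
```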
diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py
index 5a6feb5d8eca..1a3c112503c7 100644
--- a/src/transformers/models/t5/configuration_t5.py
+++ b/src/transformers/models/t5/configuration_t5.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 """ T5 model configuration """
 from collections import OrderedDict
-from typing import Any, Mapping, Optional
+from typing import Any, Dict, Iterable, Mapping, Optional
 
 from transformers import PreTrainedTokenizer, TensorType
 
+from ... import is_torch_available
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfigWithPast
 from ...utils import logging
@@ -140,9 +141,6 @@ def num_hidden_layers(self):
 
 
 class T5OnnxConfig(OnnxConfigWithPast):
-    def __init__(self, config: PretrainedConfig, use_past: bool = False):
-        super().__init__(config, use_past)
-
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
         common_inputs = OrderedDict(
@@ -155,29 +153,30 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
         )
 
         if self.use_past:
-            for i in range(self._config.num_layers):
-                common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
-                common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
+            for i in range(0, self._config.num_layers):
+                common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
+                common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
 
         return common_inputs
 
     @property
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_outputs = OrderedDict(
-            [
-                ("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
-                ("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
-            ]
-        )
+        common_outputs = super().outputs
+
+        if "last_hidden_state" in common_outputs:
+            common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
 
         if self.use_past:
             for i in range(self._config.num_layers):
-                common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
-                common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
+                common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
+                common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
+                common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
+
+        if self.task == "default":
+            common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
 
         return common_outputs
 
@@ -189,8 +188,6 @@ def generate_dummy_inputs(
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
-        if self.use_past:
-            raise NotImplementedError()
 
         # Generate encoder inputs
         encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
@@ -199,4 +196,47 @@ def generate_dummy_inputs(
         decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
         decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
 
-        return dict(**encoder_inputs, **decoder_inputs)
+        ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_key_values inputs without PyTorch installed.")
+            else:
+                import torch
+                batch = encoder_inputs["input_ids"].shape[0]
+                encoder_seq_length = encoder_inputs["input_ids"].shape[1]
+                encoder_shape = (
+                    batch,
+                    self._config.num_heads,
+                    encoder_seq_length,
+                    self._config.hidden_size // self._config.num_heads,
+                )
+                decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
+
+                ordered_inputs["past_key_values"] = []
+                for _ in range(self._config.num_layers):
+                    ordered_inputs["past_key_values"].append(
+                        (
+                            torch.zeros(decoder_shape),
+                            torch.zeros(decoder_shape),
+                            torch.zeros(encoder_shape),
+                            torch.zeros(encoder_shape),
+                        )
+                    )
+
+        return ordered_inputs
+
+    @classmethod
+    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                # Each layer caches a 4-tuple: decoder key/value, then the
+                # cross-attention (encoder) key/value.
+                flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
+                flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
+                flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
+                flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
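
For T5, each decoder layer's cache is a 4-tuple: self-attention key/value plus cross-attention (encoder) key/value, which is why `generate_dummy_inputs` above builds four tensors per layer. A standalone sketch of the shapes involved, using toy hyperparameters rather than a real checkpoint's values:

```python
import torch

# Toy hyperparameters, not those of an actual T5 checkpoint.
batch, num_heads, num_layers, hidden_size = 2, 8, 6, 512
encoder_seq_length = 7
head_dim = hidden_size // num_heads

# The cross-attention cache covers the full encoder sequence; the
# self-attention cache is seeded for a single decoding step.
encoder_shape = (batch, num_heads, encoder_seq_length, head_dim)
decoder_shape = (batch, num_heads, 1, head_dim)

past_key_values = [
    (
        torch.zeros(decoder_shape),  # decoder (self-attention) key
        torch.zeros(decoder_shape),  # decoder (self-attention) value
        torch.zeros(encoder_shape),  # encoder (cross-attention) key
        torch.zeros(encoder_shape),  # encoder (cross-attention) value
    )
    for _ in range(num_layers)
]
assert len(past_key_values) == num_layers and len(past_key_values[0]) == 4
```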
encoder_inputs["input_ids"].shape[1] + encoder_shape = ( + batch, + self._config.num_heads, + encoder_seq_length, + self._config.hidden_size // self._config.num_heads, + ) + decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads) + + ordered_inputs["past_key_values"] = [] + for _ in range(self._config.num_layers): + ordered_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + + return ordered_inputs + + @staticmethod + def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + if name in ["present", "past_key_values"]: + flatten_output = {} + for idx, t in enumerate(field): + flatten_output[f"{name}.{idx}.decoder.key"] = t[0] + flatten_output[f"{name}.{idx}.decoder.value"] = t[1] + flatten_output[f"{name}.{idx}.encoder.key"] = t[2] + flatten_output[f"{name}.{idx}.encoder.value"] = t[3] + + return flatten_output + + return super().flatten_output_collection_property(name, field) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 948b3ed33bc2..8f07f04c5508 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -429,8 +429,6 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] - int_seq_length = int(seq_length) - real_seq_length = seq_length if past_key_value is not None: @@ -499,7 +497,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -int_seq_length:, :] + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -629,7 +627,7 @@ def forward( if len(past_key_value) != expected_num_past_key_values: raise ValueError( f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." + f"{'2 (past / key) for cross attention. 
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 56512d3652d8..8e9e1575b1e7 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -14,7 +14,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, List, Mapping, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
 
 from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
 
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
     _TASKS_TO_COMMON_OUTPUTS = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
         "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
         "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
@@ -228,6 +229,24 @@ def restore_ops(self):
             orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
             setattr(spec.o, spec.name, orig_op)
 
+    @classmethod
+    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        """
+        Flatten any potential nested structure expanding the name of the field with the index of the element within the
+        structure.
+
+        Args:
+            name: The name of the nested structure
+            field: The structure to, potentially, be flattened
+
+        Returns:
+            (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
+
+        """
+        from itertools import chain
+
+        return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
+
 
 class OnnxConfigWithPast(OnnxConfig, ABC):
     def __init__(
@@ -285,3 +304,17 @@ def generate_dummy_inputs(
         # Generate dummy inputs according to compute batch and sequence
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
         return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
+
+    @classmethod
+    def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
+        # A classmethod rather than a staticmethod: the zero-argument super()
+        # call below needs a class binding to resolve.
+        if name in ["present", "past_key_values"]:
+            flatten_output = {}
+            for idx, t in enumerate(field):
+                flatten_output[f"{name}.{idx}.key"] = t[0]
+                flatten_output[f"{name}.{idx}.value"] = t[1]
+
+            return flatten_output
+
+        return super().flatten_output_collection_property(name, field)
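
The two flattening behaviours side by side: the base `OnnxConfig` version enumerates a chained iterable, while the with-past override names each key/value pair explicitly. A self-contained sketch of what each produces, with plain functions standing in for the methods:

```python
from itertools import chain

def flatten_default(name, field):
    # Base OnnxConfig behaviour: flat enumeration of the chained elements.
    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}

def flatten_with_past(name, field):
    # OnnxConfigWithPast behaviour: explicit .key/.value names per layer.
    out = {}
    for idx, (key, value) in enumerate(field):
        out[f"{name}.{idx}.key"] = key
        out[f"{name}.{idx}.value"] = value
    return out

assert flatten_default("past_key", [[0], [1], [2]]) == {
    "past_key.0": 0, "past_key.1": 1, "past_key.2": 2
}
assert list(flatten_with_past("present", [("k0", "v0"), ("k1", "v1")])) == [
    "present.0.key", "present.0.value", "present.1.key", "present.1.value"
]
```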
diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py
index 86491a7f5b87..62fa1191c382 100644
--- a/src/transformers/onnx/convert.py
+++ b/src/transformers/onnx/convert.py
@@ -24,7 +24,6 @@
 from ..file_utils import is_torch_onnx_dict_inputs_support_available
 from ..utils import logging
 from .config import OnnxConfig
-from .utils import flatten_output_collection_property
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -163,7 +162,7 @@ def validate_model_outputs(
         if name == "past_key_values":
             name = "present"
         if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
             ref_outputs_dict.update(value)
         else:
             ref_outputs_dict[name] = value
@@ -172,7 +171,7 @@ def validate_model_outputs(
     onnx_inputs = {}
     for name, value in reference_model_inputs.items():
         if isinstance(value, (list, tuple)):
-            value = flatten_output_collection_property(name, value)
+            value = config.flatten_output_collection_property(name, value)
             onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
         else:
             onnx_inputs[name] = value.numpy()
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 8d06176cbbd5..530f68b20a1e 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -21,6 +21,7 @@
     AutoModelForCausalLM,
     AutoModelForMultipleChoice,
     AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
 )
@@ -46,6 +47,7 @@ class FeaturesManager:
     _TASKS_TO_AUTOMODELS = {
         "default": AutoModel,
         "causal-lm": AutoModelForCausalLM,
+        "seq2seq-lm": AutoModelForSeq2SeqLM,
         "sequence-classification": AutoModelForSequenceClassification,
         "token-classification": AutoModelForTokenClassification,
         "multiple-choice": AutoModelForMultipleChoice,
@@ -61,7 +63,9 @@ class FeaturesManager:
         "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
         "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
         "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
-        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
+        "t5": supported_features_mapping(
+            "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+        ),
         "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
         "gpt-neo": supported_features_mapping(
             "default",
diff --git a/src/transformers/onnx/utils.py b/src/transformers/onnx/utils.py
index b32c99119ddc..def160e6c7bb 100644
--- a/src/transformers/onnx/utils.py
+++ b/src/transformers/onnx/utils.py
@@ -14,7 +14,6 @@
 
 from ctypes import c_float, sizeof
 from enum import Enum
-from typing import Any, Dict, Iterable
 
 
 class ParameterFormat(Enum):
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int:
         Size (in byte) taken to save all the parameters
     """
     return num_parameters * dtype.size
-
-
-def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
-    """
-    Flatten any potential nested structure expanding the name of the field with the index of the element within the
-    structure.
-
-    Args:
-        name: The name of the nested structure
-        field: The structure to, potentially, be flattened
-
-    Returns:
-        (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
-
-    """
-    from itertools import chain
-
-    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
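
With `t5` now registered for the `seq2seq-lm` and `*-with-past` features, an export can be driven through the Python API roughly as follows. This is a sketch, not a verified recipe: the `export()` and `validate_model_outputs()` signatures are read off this PR's `convert.py` imports, and the assumption that `T5OnnxConfig` accepts `task` and `use_past` keyword arguments follows from the `OnnxConfigWithPast` constructor it now inherits.

```python
from pathlib import Path

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.models.t5.configuration_t5 import T5OnnxConfig
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Assumed constructor arguments, inherited from OnnxConfigWithPast.
onnx_config = T5OnnxConfig(model.config, task="seq2seq-lm", use_past=True)

output = Path("t5-small.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, atol=1e-4)
```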
- - """ - from itertools import chain - - return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))} diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index 6a2ce960d787..4bf7529514ad 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -34,11 +34,7 @@ validate_model_outputs, ) from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast -from transformers.onnx.utils import ( - compute_effective_axis_dimension, - compute_serialized_parameters_size, - flatten_output_collection_property, -) +from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size from transformers.testing_utils import require_onnx, require_torch, slow @@ -95,7 +91,7 @@ def test_flatten_output_collection_property(self): ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n} """ self.assertEqual( - flatten_output_collection_property("past_key", [[0], [1], [2]]), + OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]), { "past_key.0": 0, "past_key.1": 1,