From 36ea9f4b2611b194e27a1a2a9fc9654d56436ad3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 20 Feb 2023 16:14:23 +0100
Subject: [PATCH 1/3] support IO Binding for merged decoder

---
 optimum/onnxruntime/base.py             | 53 +++++++++++++------------
 optimum/onnxruntime/modeling_decoder.py |  3 ++
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 4f1fb2cf23..05b1832aef 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -157,21 +157,18 @@ def __init__(
     def prepare_inputs_for_merged(
         self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None
     ):
-        # Prepare use cache
-        if past_key_values is None and self.parent_model.use_merged:  # Uses "no past" branch of a merged decoder
+        if past_key_values is None and self.parent_model.use_merged:
+            # Uses "no past" branch of a merged decoder
             use_cache_branch = torch.full((1,), False).to(self.device)
-        elif (
-            past_key_values is not None and self.parent_model.use_merged
-        ):  # Uses "with past" branch of a merged decoder
+        elif past_key_values is not None and self.parent_model.use_merged:
+            # Uses "with past" branch of a merged decoder
             use_cache_branch = torch.full((1,), True).to(self.device)
-        else:  # Uses separate decoders
+        else:
+            # Uses separate decoders
             use_cache_branch = None
 
-        # Prepare past key values
-        is_dummy = False
-        if (
-            self.parent_model.use_merged and past_key_values is None
-        ):  # Generate dummy past for the first forward if uses a merged decoder
+        # Generate dummy past for the first forward if uses a merged decoder
+        if self.parent_model.use_merged and past_key_values is None:
             batch_size = input_ids.size(0)
             num_attention_heads = self.normalized_config.num_attention_heads
             embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
@@ -190,12 +187,14 @@ def prepare_inputs_for_merged(
             shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
             key_or_value = torch.zeros(shape, dtype=torch.float32).to(self.device)
             past_key_values = [key_or_value for _ in range(len(self.key_value_input_names))]
-            is_dummy = True
 
-        return use_cache_branch, past_key_values, is_dummy
+        return use_cache_branch, past_key_values
 
     def compute_past_key_values_output_shapes(
-        self, input_ids: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+        self,
+        input_ids: torch.Tensor,
+        use_cache_branch: Optional[bool],
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
     ) -> Dict[str, List[int]]:
         """
         Computes the outputs of the past key / value because it is not always easy to perform shape inference on them,
@@ -204,6 +203,9 @@ def compute_past_key_values_output_shapes(
         Args:
             input_ids (`torch.Tensor`):
                 The input ids that are associated with the current inputs.
+            use_cache_branch (`bool`):
+                In the case of a merged decoder, whether the with-past branch is used. In case the decoders without and with past are
+                separate, this parameter should be None.
             past_key_values (`Optional[Tuple[Tuple[torch.Tensor]]]`, defaults to `None`):
                 The past key values associated with the current inputs.
 
@@ -213,8 +215,11 @@ def compute_past_key_values_output_shapes(
         batch_size = input_ids.size(0)
         num_attention_heads = self.normalized_config.num_attention_heads
         embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+        sequence_length = input_ids.size(1)
 
-        if past_key_values is not None:
+        if past_key_values is not None and use_cache_branch is not False:
+            # Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the with past branch
+            # of a merged decoder is used
             sequence_length += past_key_values[0].size(2)
 
         half_shape = [batch_size, num_attention_heads]
@@ -242,20 +247,17 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> CausalLMOutputWithCrossAttentions:
-        known_output_shapes = {}
 
         # Flatten the past_key_values
         if past_key_values is not None:
             past_key_values = [past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer]
 
+        # no-ops if merged decoder is not used
+        use_cache_branch, past_key_values = self.prepare_inputs_for_merged(input_ids, past_key_values)
+
         if self.device.type == "cuda" and self.parent_model.use_io_binding:
-            # TODO: support merged decoder with IO Binding
-            if self.parent_model.use_merged is True:
-                raise ValueError(
-                    "Merged decoder without / with past is currently not supported when using IO Binding but will be in a future release."
-                    " Please disable IO Binding setting model.use_io_binding = False, or use a model that does not merge the decoder."
-                )
             known_output_shapes = self.compute_past_key_values_output_shapes(
                 input_ids,
+                use_cache_branch=use_cache_branch.item() if use_cache_branch is not None else None,
                 past_key_values=past_key_values,
             )
@@ -267,6 +269,9 @@ def forward(
             if past_key_values is not None:
                 model_inputs += past_key_values
 
+            if use_cache_branch is not None:
+                model_inputs.append(use_cache_branch)
+
             if "labels" in self.input_names:
                 model_inputs.append(labels)
                 known_output_shapes.update({"loss": []})
@@ -302,9 +307,6 @@ def forward(
             }
 
             if self.parent_model.use_merged is True:
-                use_cache_branch, past_key_values, is_dummy = self.prepare_inputs_for_merged(
-                    input_ids, past_key_values
-                )
                 onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy()
 
             if past_key_values is not None:
@@ -371,7 +373,6 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Seq2SeqLMOutput:
-        known_output_shapes = {}
         # Flatten the past_key_values
         if past_key_values is not None:
             past_key_values = tuple(
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index fa54a6ec43..2a8876b40a 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -589,8 +589,11 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
+        use_cache_branch: None = None,
         **kwargs,
     ) -> CausalLMOutputWithCrossAttentions:
+        # adding use_cache_branch in the signature here is just a hack for IO Binding
+
         if past_key_values is None or self.use_cache is False:
             outputs = self.decoder(
                 input_ids=input_ids,

From 3f93aa727207b62ecafbd3399be1e5f395f81b15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 20 Feb 2023 16:18:29 +0100
Subject: [PATCH 2/3] fix test

---
 tests/onnxruntime/test_modeling.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 8dc6cf0a76..14b4f3f2b3 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -2121,16 +2121,27 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode
 
         self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged))
 
-    @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+    )
     @require_torch_gpu
     @pytest.mark.gpu_test
-    def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool):
-        model_args = {"test_name": test_name, "model_arch": model_arch, "use_cache": use_cache}
+    def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+        model_args = {
+            "test_name": test_name,
+            "model_arch": model_arch,
+            "use_cache": use_cache,
+            "use_merged": use_merged,
+        }
         self._setup(model_args)
 
         model_id = MODEL_NAMES[model_arch]
-        onnx_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False)
-        io_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True)
+        onnx_model = ORTModelForCausalLM.from_pretrained(
+            self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=False
+        ).to("cuda")
+        io_model = ORTModelForCausalLM.from_pretrained(
+            self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=True
+        ).to("cuda")
 
         tokenizer = get_preprocessor(model_id)
         tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda")

From a932fc391571d5062a62c17cc309b9c084d1072f Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 20 Feb 2023 16:28:13 +0100
Subject: [PATCH 3/3] Update optimum/onnxruntime/base.py

---
 optimum/onnxruntime/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 05b1832aef..2b4ad6a5ba 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -203,7 +203,7 @@ def compute_past_key_values_output_shapes(
         Args:
             input_ids (`torch.Tensor`):
                 The input ids that are associated with the current inputs.
-            use_cache_branch (`bool`):
+            use_cache_branch (`Optional[bool]`):
                 In the case of a merged decoder, whether the with-past branch is used. In case the decoders without and with past are
                 separate, this parameter should be None.
             past_key_values (`Optional[Tuple[Tuple[torch.Tensor]]]`, defaults to `None`):
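
Note (not part of the patches above): a minimal usage sketch of the path this series enables,
i.e. a merged without/with-past decoder driven through IO Binding on CUDA. It mirrors the
updated test; the local model directory "./gpt2_onnx_merged", the presence of tokenizer files
in it, and the generation settings are placeholder assumptions, not something the patches define.

    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    # Assumption: "./gpt2_onnx_merged" holds a causal LM exported with a merged
    # without-past / with-past decoder, plus its tokenizer files.
    model = ORTModelForCausalLM.from_pretrained(
        "./gpt2_onnx_merged", use_cache=True, use_io_binding=True
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("./gpt2_onnx_merged")

    tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
    # The first forward takes the "no past" branch (use_cache_branch=False with dummy
    # past key/values); later steps reuse the cache through the "with past" branch.
    generated = model.generate(**tokens, max_new_tokens=20)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))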