Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IO Binding for merged decoder #797

Merged
merged 3 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,18 @@ def __init__(
def prepare_inputs_for_merged(
self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None
):
# Prepare use cache
if past_key_values is None and self.parent_model.use_merged: # Uses "no past" branch of a merged decoder
if past_key_values is None and self.parent_model.use_merged:
# Uses "no past" branch of a merged decoder
use_cache_branch = torch.full((1,), False).to(self.device)
elif (
past_key_values is not None and self.parent_model.use_merged
): # Uses "with past" branch of a merged decoder
elif past_key_values is not None and self.parent_model.use_merged:
# Uses "with past" branch of a merged decoder
use_cache_branch = torch.full((1,), True).to(self.device)
else: # Uses separate decoders
else:
# Uses separate decoders
use_cache_branch = None

# Prepare past key values
is_dummy = False
if (
self.parent_model.use_merged and past_key_values is None
): # Generate dummy past for the first forward if uses a merged decoder
# Generate dummy past for the first forward if uses a merged decoder
if self.parent_model.use_merged and past_key_values is None:
batch_size = input_ids.size(0)
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
Expand All @@ -190,12 +187,14 @@ def prepare_inputs_for_merged(
shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
key_or_value = torch.zeros(shape, dtype=torch.float32).to(self.device)
past_key_values = [key_or_value for _ in range(len(self.key_value_input_names))]
is_dummy = True

return use_cache_branch, past_key_values, is_dummy
return use_cache_branch, past_key_values

def compute_past_key_values_output_shapes(
self, input_ids: torch.Tensor, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
self,
input_ids: torch.Tensor,
use_cache_branch: Optional[bool],
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Dict[str, List[int]]:
"""
Computes the outputs of the past key / value because it is not always easy to perform shape inference on them,
Expand All @@ -204,6 +203,9 @@ def compute_past_key_values_output_shapes(
Args:
input_ids (`torch.Tensor`):
The input ids that are associated with the current inputs.
use_cache_branch (`Optional[bool]`):
In the case of a merged decoder, whether the with-past branch is used. In case the decoders without and with past are
separate, this parameter should be None.
past_key_values (`Optional[Tuple[Tuple[torch.Tensor]]]`, defaults to `None`):
The past key values associated with the current inputs.

Expand All @@ -213,8 +215,11 @@ def compute_past_key_values_output_shapes(
batch_size = input_ids.size(0)
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

sequence_length = input_ids.size(1)
if past_key_values is not None:
if past_key_values is not None and use_cache_branch is not False:
# Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the with past branch
# of a merged decoder is used
sequence_length += past_key_values[0].size(2)

half_shape = [batch_size, num_attention_heads]
Expand Down Expand Up @@ -242,20 +247,17 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> CausalLMOutputWithCrossAttentions:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = [past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer]

# no-ops if merged decoder is not used
use_cache_branch, past_key_values = self.prepare_inputs_for_merged(input_ids, past_key_values)

if self.device.type == "cuda" and self.parent_model.use_io_binding:
# TODO: support merged decoder with IO Binding
if self.parent_model.use_merged is True:
raise ValueError(
"Merged decoder without / with past is currently not supported when using IO Binding but will be in a future release."
" Please disable IO Binding setting model.use_io_binding = False, or use a model that does not merge the decoder."
)
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
use_cache_branch=use_cache_branch.item() if use_cache_branch is not None else None,
past_key_values=past_key_values,
)

Expand All @@ -267,6 +269,9 @@ def forward(
if past_key_values is not None:
model_inputs += past_key_values

if use_cache_branch is not None:
model_inputs.append(use_cache_branch)

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})
Expand Down Expand Up @@ -302,9 +307,6 @@ def forward(
}

if self.parent_model.use_merged is True:
use_cache_branch, past_key_values, is_dummy = self.prepare_inputs_for_merged(
input_ids, past_key_values
)
onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy()

if past_key_values is not None:
Expand Down Expand Up @@ -371,7 +373,6 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> Seq2SeqLMOutput:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = tuple(
Expand Down
3 changes: 3 additions & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,11 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: None = None,
**kwargs,
) -> CausalLMOutputWithCrossAttentions:
# adding use_cache_branch in the signature here is just a hack for IO Binding

if past_key_values is None or self.use_cache is False:
outputs = self.decoder(
input_ids=input_ids,
Expand Down
21 changes: 16 additions & 5 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,16 +2121,27 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode

self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged))

@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
@require_torch_gpu
@pytest.mark.gpu_test
def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool):
model_args = {"test_name": test_name, "model_arch": model_arch, "use_cache": use_cache}
def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
model_args = {
"test_name": test_name,
"model_arch": model_arch,
"use_cache": use_cache,
"use_merged": use_merged,
}
self._setup(model_args)

model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False)
io_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True)
onnx_model = ORTModelForCausalLM.from_pretrained(
self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=False
).to("cuda")
io_model = ORTModelForCausalLM.from_pretrained(
self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=True
).to("cuda")

tokenizer = get_preprocessor(model_id)
tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda")
Expand Down